diff --git a/connection.go b/connection.go index 813f229..de2257b 100644 --- a/connection.go +++ b/connection.go @@ -73,15 +73,19 @@ func (c *conn) BeginTx(ctx context.Context, opts driver.TxOptions) (driver.Tx, e // Ping attempts to verify that the server is accessible. // Returns ErrBadConn if ping fails and consequently DB.Ping will remove the conn from the pool. func (c *conn) Ping(ctx context.Context) error { - log := logger.WithContext(c.id, driverctx.CorrelationIdFromContext(ctx), "") ctx = driverctx.NewContextWithConnId(ctx, c.id) + log, _ := client.LoggerAndContext(ctx, nil) + log.Debug().Msg("databricks: pinging") + ctx1, cancel := context.WithTimeout(ctx, c.cfg.PingTimeout) defer cancel() _, err := c.QueryContext(ctx1, "select 1", nil) if err != nil { log.Err(err).Msg("databricks: failed to ping") - return driver.ErrBadConn + return dbsqlerrint.NewBadConnectionError(err) } + + log.Debug().Msg("databricks: ping successful") return nil } @@ -102,52 +106,21 @@ func (c *conn) IsValid() bool { // ExecContext honors the context timeout and return when it is canceled. // Statement ExecContext is the same as connection ExecContext func (c *conn) ExecContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Result, error) { - - corrId := driverctx.CorrelationIdFromContext(ctx) - log := logger.WithContext(c.id, corrId, "") + ctx = driverctx.NewContextWithConnId(ctx, c.id) + log, _ := client.LoggerAndContext(ctx, nil) msg, start := logger.Track("ExecContext") defer log.Duration(msg, start) - ctx = driverctx.NewContextWithConnId(ctx, c.id) - + corrId := driverctx.CorrelationIdFromContext(ctx) if len(args) > 0 && c.session.ServerProtocolVersion < cli_service.TProtocolVersion_SPARK_CLI_SERVICE_PROTOCOL_V8 { return nil, dbsqlerrint.NewDriverError(ctx, dbsqlerr.ErrParametersNotSupported, nil) } exStmtResp, opStatusResp, err := c.runQuery(ctx, query, args) + log, ctx = client.LoggerAndContext(ctx, exStmtResp) + stagingErr := c.execStagingOperation(exStmtResp, ctx) if exStmtResp != nil && exStmtResp.OperationHandle != nil { - var isStagingOperation bool - if exStmtResp.DirectResults != nil && exStmtResp.DirectResults.ResultSetMetadata != nil && exStmtResp.DirectResults.ResultSetMetadata.IsStagingOperation != nil { - isStagingOperation = *exStmtResp.DirectResults.ResultSetMetadata.IsStagingOperation - } else { - req := cli_service.TGetResultSetMetadataReq{ - OperationHandle: exStmtResp.OperationHandle, - } - resp, err := c.client.GetResultSetMetadata(ctx, &req) - if err != nil { - return nil, dbsqlerrint.NewDriverError(ctx, "error performing staging operation", err) - } - isStagingOperation = *resp.IsStagingOperation - } - if isStagingOperation { - if len(driverctx.StagingPathsFromContext(ctx)) != 0 { - row, err := rows.NewRows(c.id, corrId, exStmtResp.OperationHandle, c.client, c.cfg, exStmtResp.DirectResults) - if err != nil { - return nil, dbsqlerrint.NewDriverError(ctx, "error reading row.", err) - } - err = c.ExecStagingOperation(ctx, row) - if err != nil { - return nil, err - } - } else { - return nil, dbsqlerrint.NewDriverError(ctx, "staging ctx must be provided.", nil) - } - } - - // we have an operation id so update the logger - log = logger.WithContext(c.id, corrId, client.SprintGuid(exStmtResp.OperationHandle.OperationId.GUID)) - // since we have an operation handle we can close the operation if necessary alreadyClosed := exStmtResp.DirectResults != nil && exStmtResp.DirectResults.CloseOperation != nil newCtx := driverctx.NewContextWithCorrelationId(driverctx.NewContextWithConnId(context.Background(), c.id), corrId) @@ -160,171 +133,20 @@ func (c *conn) ExecContext(ctx context.Context, query string, args []driver.Name } } } + if err != nil { log.Err(err).Msgf("databricks: failed to execute query: query %s", query) return nil, dbsqlerrint.NewExecutionError(ctx, dbsqlerr.ErrQueryExecution, err, opStatusResp) } - res := result{AffectedRows: opStatusResp.GetNumModifiedRows()} - - return &res, nil -} - -func Succeeded(response *http.Response) bool { - if response.StatusCode == 200 || response.StatusCode == 201 || response.StatusCode == 202 || response.StatusCode == 204 { - return true - } - return false -} - -func (c *conn) HandleStagingPut(ctx context.Context, presignedUrl string, headers map[string]string, localFile string) dbsqlerr.DBError { - if localFile == "" { - return dbsqlerrint.NewDriverError(ctx, "cannot perform PUT without specifying a local_file", nil) - } - client := &http.Client{} - - dat, err := os.ReadFile(localFile) - - if err != nil { - return dbsqlerrint.NewDriverError(ctx, "error reading local file", err) - } - - req, _ := http.NewRequest("PUT", presignedUrl, bytes.NewReader(dat)) - - for k, v := range headers { - req.Header.Set(k, v) - } - res, err := client.Do(req) - if err != nil { - return dbsqlerrint.NewDriverError(ctx, "error sending http request", err) - } - defer res.Body.Close() - content, err := io.ReadAll(res.Body) - - if err != nil || !Succeeded(res) { - return dbsqlerrint.NewDriverError(ctx, fmt.Sprintf("staging operation over HTTP was unsuccessful: %d-%s", res.StatusCode, content), nil) - } - return nil - -} - -func (c *conn) HandleStagingGet(ctx context.Context, presignedUrl string, headers map[string]string, localFile string) dbsqlerr.DBError { - if localFile == "" { - return dbsqlerrint.NewDriverError(ctx, "cannot perform GET without specifying a local_file", nil) - } - client := &http.Client{} - req, _ := http.NewRequest("GET", presignedUrl, nil) - - for k, v := range headers { - req.Header.Set(k, v) - } - res, err := client.Do(req) - if err != nil { - return dbsqlerrint.NewDriverError(ctx, "error sending http request", err) - } - defer res.Body.Close() - content, err := io.ReadAll(res.Body) - - if err != nil || !Succeeded(res) { - return dbsqlerrint.NewDriverError(ctx, fmt.Sprintf("staging operation over HTTP was unsuccessful: %d-%s", res.StatusCode, content), nil) - } - - err = os.WriteFile(localFile, content, 0644) //nolint:gosec - if err != nil { - return dbsqlerrint.NewDriverError(ctx, "error writing local file", err) - } - return nil -} - -func (c *conn) HandleStagingDelete(ctx context.Context, presignedUrl string, headers map[string]string) dbsqlerr.DBError { - client := &http.Client{} - req, _ := http.NewRequest("DELETE", presignedUrl, nil) - for k, v := range headers { - req.Header.Set(k, v) - } - res, err := client.Do(req) - if err != nil { - return dbsqlerrint.NewDriverError(ctx, "error sending http request", err) - } - defer res.Body.Close() - content, err := io.ReadAll(res.Body) - - if err != nil || !Succeeded(res) { - return dbsqlerrint.NewDriverError(ctx, fmt.Sprintf("staging operation over HTTP was unsuccessful: %d-%s, nil", res.StatusCode, content), nil) - } - - return nil -} - -func localPathIsAllowed(stagingAllowedLocalPaths []string, localFile string) bool { - for i := range stagingAllowedLocalPaths { - // Convert both filepaths to absolute paths to avoid potential issues. - // - path, err := filepath.Abs(stagingAllowedLocalPaths[i]) - if err != nil { - return false - } - localFile, err := filepath.Abs(localFile) - if err != nil { - return false - } - relativePath, err := filepath.Rel(path, localFile) - if err != nil { - return false - } - if !strings.Contains(relativePath, "../") { - return true - } + if stagingErr != nil { + log.Err(stagingErr).Msgf("databricks: failed to execute query: query %s", query) + return nil, dbsqlerrint.NewExecutionError(ctx, dbsqlerr.ErrQueryExecution, err, opStatusResp) } - return false -} -func (c *conn) ExecStagingOperation( - ctx context.Context, - row driver.Rows) dbsqlerr.DBError { + res := result{AffectedRows: opStatusResp.GetNumModifiedRows()} - var sqlRow []driver.Value - colNames := row.Columns() - sqlRow = make([]driver.Value, len(colNames)) - err := row.Next(sqlRow) - if err != nil { - return dbsqlerrint.NewDriverError(ctx, "error fetching staging operation results", err) - } - var stringValues []string = make([]string, 4) - for i := range stringValues { - if s, ok := sqlRow[i].(string); ok { - stringValues[i] = s - } else { - return dbsqlerrint.NewDriverError(ctx, "received unexpected response from the server.", nil) - } - } - operation := stringValues[0] - presignedUrl := stringValues[1] - headersByteArr := []byte(stringValues[2]) - var headers map[string]string - if err := json.Unmarshal(headersByteArr, &headers); err != nil { - return dbsqlerrint.NewDriverError(ctx, "error parsing server response.", nil) - } - localFile := stringValues[3] - stagingAllowedLocalPaths := driverctx.StagingPathsFromContext(ctx) - switch operation { - case "PUT": - if localPathIsAllowed(stagingAllowedLocalPaths, localFile) { - return c.HandleStagingPut(ctx, presignedUrl, headers, localFile) - } else { - return dbsqlerrint.NewDriverError(ctx, "local file operations are restricted to paths within the configured stagingAllowedLocalPath", nil) - } - case "GET": - if localPathIsAllowed(stagingAllowedLocalPaths, localFile) { - return c.HandleStagingGet(ctx, presignedUrl, headers, localFile) - } else { - return dbsqlerrint.NewDriverError(ctx, "local file operations are restricted to paths within the configured stagingAllowedLocalPath", nil) - } - case "DELETE": - return c.HandleStagingDelete(ctx, presignedUrl, headers) - default: - return dbsqlerrint.NewDriverError(ctx, fmt.Sprintf("operation %s is not supported. Supported operations are GET, PUT, and REMOVE", operation), nil) - } + return &res, nil } // QueryContext executes a query that may return rows, such as a @@ -333,11 +155,9 @@ func (c *conn) ExecStagingOperation( // QueryContext honors the context timeout and return when it is canceled. // Statement QueryContext is the same as connection QueryContext func (c *conn) QueryContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Rows, error) { - corrId := driverctx.CorrelationIdFromContext(ctx) - log := logger.WithContext(c.id, corrId, "") - msg, start := log.Track("QueryContext") - ctx = driverctx.NewContextWithConnId(ctx, c.id) + log, _ := client.LoggerAndContext(ctx, nil) + msg, start := log.Track("QueryContext") if len(args) > 0 && c.session.ServerProtocolVersion < cli_service.TProtocolVersion_SPARK_CLI_SERVICE_PROTOCOL_V8 { return nil, dbsqlerrint.NewDriverError(ctx, dbsqlerr.ErrParametersNotSupported, nil) @@ -346,43 +166,33 @@ func (c *conn) QueryContext(ctx context.Context, query string, args []driver.Nam // first we try to get the results synchronously. // at any point in time that the context is done we must cancel and return exStmtResp, opStatusResp, err := c.runQuery(ctx, query, args) - - if exStmtResp != nil && exStmtResp.OperationHandle != nil { - ctx = driverctx.NewContextWithQueryId(ctx, client.SprintGuid(exStmtResp.OperationHandle.OperationId.GUID)) - log = logger.WithContext(c.id, driverctx.CorrelationIdFromContext(ctx), client.SprintGuid(exStmtResp.OperationHandle.OperationId.GUID)) - } + log, ctx = client.LoggerAndContext(ctx, exStmtResp) defer log.Duration(msg, start) if err != nil { log.Err(err).Msg("databricks: failed to run query") // To log query we need to redact credentials return nil, dbsqlerrint.NewExecutionError(ctx, dbsqlerr.ErrQueryExecution, err, opStatusResp) } - // hold on to the operation handle - opHandle := exStmtResp.OperationHandle - rows, err := rows.NewRows(c.id, corrId, opHandle, c.client, c.cfg, exStmtResp.DirectResults) + corrId := driverctx.CorrelationIdFromContext(ctx) + rows, err := rows.NewRows(c.id, corrId, exStmtResp.OperationHandle, c.client, c.cfg, exStmtResp.DirectResults) return rows, err } func (c *conn) runQuery(ctx context.Context, query string, args []driver.NamedValue) (*cli_service.TExecuteStatementResp, *cli_service.TGetOperationStatusResp, error) { - log := logger.WithContext(c.id, driverctx.CorrelationIdFromContext(ctx), "") // first we try to get the results synchronously. // at any point in time that the context is done we must cancel and return exStmtResp, err := c.executeStatement(ctx, query, args) + var log *logger.DBSQLLogger + log, ctx = client.LoggerAndContext(ctx, exStmtResp) if err != nil { return exStmtResp, nil, err } + opHandle := exStmtResp.OperationHandle - if opHandle != nil && opHandle.OperationId != nil { - ctx = driverctx.NewContextWithQueryId(ctx, client.SprintGuid(opHandle.OperationId.GUID)) - log = logger.WithContext( - c.id, - driverctx.CorrelationIdFromContext(ctx), driverctx.QueryIdFromContext(ctx), - ) - } if exStmtResp.DirectResults != nil { opStatus := exStmtResp.DirectResults.GetOperationStatus() @@ -470,8 +280,7 @@ func invalidOperationState(ctx context.Context, opStatus *cli_service.TGetOperat } func (c *conn) executeStatement(ctx context.Context, query string, args []driver.NamedValue) (*cli_service.TExecuteStatementResp, error) { - corrId := driverctx.CorrelationIdFromContext(ctx) - log := logger.WithContext(c.id, corrId, "") + ctx = driverctx.NewContextWithConnId(ctx, c.id) req := cli_service.TExecuteStatementReq{ SessionHandle: c.session.SessionHandle, @@ -499,22 +308,23 @@ func (c *conn) executeStatement(ctx context.Context, query string, args []driver req.CanDownloadResult_ = &c.cfg.UseCloudFetch } - ctx = driverctx.NewContextWithConnId(ctx, c.id) resp, err := c.client.ExecuteStatement(ctx, &req) + var log *logger.DBSQLLogger + log, ctx = client.LoggerAndContext(ctx, resp) var shouldCancel = func(resp *cli_service.TExecuteStatementResp) bool { if resp == nil { return false } hasHandle := resp.OperationHandle != nil - isOpen := resp.DirectResults != nil && resp.DirectResults.CloseOperation == nil + isOpen := resp.DirectResults == nil || resp.DirectResults.CloseOperation == nil return hasHandle && isOpen } select { default: case <-ctx.Done(): - newCtx := driverctx.NewContextWithCorrelationId(driverctx.NewContextWithConnId(context.Background(), c.id), corrId) + newCtx := driverctx.NewContextFromBackground(ctx) // in case context is done, we need to cancel the operation if necessary if err == nil && shouldCancel(resp) { log.Debug().Msg("databricks: canceling query") @@ -523,7 +333,7 @@ func (c *conn) executeStatement(ctx context.Context, query string, args []driver }) if err1 != nil { - log.Err(err).Msgf("databricks: cancel failed") + log.Err(err1).Msgf("databricks: cancel failed") } else { log.Debug().Msgf("databricks: cancel success") } @@ -572,7 +382,7 @@ func (c *conn) pollOperation(ctx context.Context, opHandle *cli_service.TOperati }, statusResp, err }, OnCancelFn: func() (any, error) { - log.Debug().Msg("databricks: canceling query") + log.Debug().Msg("databricks: sentinel canceling query") ret, err := c.client.CancelOperation(newCtx, &cli_service.TCancelOperationReq{ OperationHandle: opHandle, }) @@ -581,6 +391,7 @@ func (c *conn) pollOperation(ctx context.Context, opHandle *cli_service.TOperati } status, resp, err := pollSentinel.Watch(ctx, c.cfg.PollInterval, 0) if err != nil { + log.Err(err).Msg("error polling operation status") if status == sentinel.WatchTimeout { err = dbsqlerrint.NewRequestError(ctx, dbsqlerr.ErrSentinelTimeout, err) } @@ -615,3 +426,196 @@ var _ driver.QueryerContext = (*conn)(nil) var _ driver.ConnPrepareContext = (*conn)(nil) var _ driver.ConnBeginTx = (*conn)(nil) var _ driver.NamedValueChecker = (*conn)(nil) + +func Succeeded(response *http.Response) bool { + if response.StatusCode == 200 || response.StatusCode == 201 || response.StatusCode == 202 || response.StatusCode == 204 { + return true + } + return false +} + +func (c *conn) handleStagingPut(ctx context.Context, presignedUrl string, headers map[string]string, localFile string) dbsqlerr.DBError { + if localFile == "" { + return dbsqlerrint.NewDriverError(ctx, "cannot perform PUT without specifying a local_file", nil) + } + client := &http.Client{} + + dat, err := os.ReadFile(localFile) + + if err != nil { + return dbsqlerrint.NewDriverError(ctx, "error reading local file", err) + } + + req, _ := http.NewRequest("PUT", presignedUrl, bytes.NewReader(dat)) + + for k, v := range headers { + req.Header.Set(k, v) + } + res, err := client.Do(req) + if err != nil { + return dbsqlerrint.NewDriverError(ctx, "error sending http request", err) + } + defer res.Body.Close() + content, err := io.ReadAll(res.Body) + + if err != nil || !Succeeded(res) { + return dbsqlerrint.NewDriverError(ctx, fmt.Sprintf("staging operation over HTTP was unsuccessful: %d-%s", res.StatusCode, content), nil) + } + return nil + +} + +func (c *conn) handleStagingGet(ctx context.Context, presignedUrl string, headers map[string]string, localFile string) dbsqlerr.DBError { + if localFile == "" { + return dbsqlerrint.NewDriverError(ctx, "cannot perform GET without specifying a local_file", nil) + } + client := &http.Client{} + req, _ := http.NewRequest("GET", presignedUrl, nil) + + for k, v := range headers { + req.Header.Set(k, v) + } + res, err := client.Do(req) + if err != nil { + return dbsqlerrint.NewDriverError(ctx, "error sending http request", err) + } + defer res.Body.Close() + content, err := io.ReadAll(res.Body) + + if err != nil || !Succeeded(res) { + return dbsqlerrint.NewDriverError(ctx, fmt.Sprintf("staging operation over HTTP was unsuccessful: %d-%s", res.StatusCode, content), nil) + } + + err = os.WriteFile(localFile, content, 0644) //nolint:gosec + if err != nil { + return dbsqlerrint.NewDriverError(ctx, "error writing local file", err) + } + return nil +} + +func (c *conn) handleStagingDelete(ctx context.Context, presignedUrl string, headers map[string]string) dbsqlerr.DBError { + client := &http.Client{} + req, _ := http.NewRequest("DELETE", presignedUrl, nil) + for k, v := range headers { + req.Header.Set(k, v) + } + res, err := client.Do(req) + if err != nil { + return dbsqlerrint.NewDriverError(ctx, "error sending http request", err) + } + defer res.Body.Close() + content, err := io.ReadAll(res.Body) + + if err != nil || !Succeeded(res) { + return dbsqlerrint.NewDriverError(ctx, fmt.Sprintf("staging operation over HTTP was unsuccessful: %d-%s, nil", res.StatusCode, content), nil) + } + + return nil +} + +func localPathIsAllowed(stagingAllowedLocalPaths []string, localFile string) bool { + for i := range stagingAllowedLocalPaths { + // Convert both filepaths to absolute paths to avoid potential issues. + // + path, err := filepath.Abs(stagingAllowedLocalPaths[i]) + if err != nil { + return false + } + localFile, err := filepath.Abs(localFile) + if err != nil { + return false + } + relativePath, err := filepath.Rel(path, localFile) + if err != nil { + return false + } + if !strings.Contains(relativePath, "../") { + return true + } + } + return false +} + +func (c *conn) execStagingOperation( + exStmtResp *cli_service.TExecuteStatementResp, + ctx context.Context) dbsqlerr.DBError { + + if exStmtResp == nil || exStmtResp.OperationHandle == nil { + return nil + } + + corrId := driverctx.CorrelationIdFromContext(ctx) + var row driver.Rows + var err error + + var isStagingOperation bool + if exStmtResp.DirectResults != nil && exStmtResp.DirectResults.ResultSetMetadata != nil && exStmtResp.DirectResults.ResultSetMetadata.IsStagingOperation != nil { + isStagingOperation = *exStmtResp.DirectResults.ResultSetMetadata.IsStagingOperation + } else { + req := cli_service.TGetResultSetMetadataReq{ + OperationHandle: exStmtResp.OperationHandle, + } + resp, err := c.client.GetResultSetMetadata(ctx, &req) + if err != nil { + return dbsqlerrint.NewDriverError(ctx, "error performing staging operation", err) + } + isStagingOperation = *resp.IsStagingOperation + } + + if !isStagingOperation { + return nil + } + + if len(driverctx.StagingPathsFromContext(ctx)) != 0 { + row, err = rows.NewRows(c.id, corrId, exStmtResp.OperationHandle, c.client, c.cfg, exStmtResp.DirectResults) + if err != nil { + return dbsqlerrint.NewDriverError(ctx, "error reading row.", err) + } + + } else { + return dbsqlerrint.NewDriverError(ctx, "staging ctx must be provided.", nil) + } + + var sqlRow []driver.Value + colNames := row.Columns() + sqlRow = make([]driver.Value, len(colNames)) + err = row.Next(sqlRow) + if err != nil { + return dbsqlerrint.NewDriverError(ctx, "error fetching staging operation results", err) + } + var stringValues []string = make([]string, 4) + for i := range stringValues { + if s, ok := sqlRow[i].(string); ok { + stringValues[i] = s + } else { + return dbsqlerrint.NewDriverError(ctx, "received unexpected response from the server.", nil) + } + } + operation := stringValues[0] + presignedUrl := stringValues[1] + headersByteArr := []byte(stringValues[2]) + var headers map[string]string + if err := json.Unmarshal(headersByteArr, &headers); err != nil { + return dbsqlerrint.NewDriverError(ctx, "error parsing server response.", nil) + } + localFile := stringValues[3] + stagingAllowedLocalPaths := driverctx.StagingPathsFromContext(ctx) + switch operation { + case "PUT": + if localPathIsAllowed(stagingAllowedLocalPaths, localFile) { + return c.handleStagingPut(ctx, presignedUrl, headers, localFile) + } else { + return dbsqlerrint.NewDriverError(ctx, "local file operations are restricted to paths within the configured stagingAllowedLocalPath", nil) + } + case "GET": + if localPathIsAllowed(stagingAllowedLocalPaths, localFile) { + return c.handleStagingGet(ctx, presignedUrl, headers, localFile) + } else { + return dbsqlerrint.NewDriverError(ctx, "local file operations are restricted to paths within the configured stagingAllowedLocalPath", nil) + } + case "DELETE": + return c.handleStagingDelete(ctx, presignedUrl, headers) + default: + return dbsqlerrint.NewDriverError(ctx, fmt.Sprintf("operation %s is not supported. Supported operations are GET, PUT, and REMOVE", operation), nil) + } +} diff --git a/connection_test.go b/connection_test.go index d349ba5..620a194 100644 --- a/connection_test.go +++ b/connection_test.go @@ -8,7 +8,9 @@ import ( "time" "github.com/apache/thrift/lib/go/thrift" + "github.com/pkg/errors" + dbsqlerr "github.com/databricks/databricks-sql-go/errors" "github.com/databricks/databricks-sql-go/internal/cli_service" "github.com/databricks/databricks-sql-go/internal/client" "github.com/databricks/databricks-sql-go/internal/config" @@ -1337,7 +1339,8 @@ func TestConn_Ping(t *testing.T) { err := testConn.Ping(context.Background()) assert.Error(t, err) - assert.Equal(t, driver.ErrBadConn, err) + assert.True(t, errors.Is(err, driver.ErrBadConn)) + assert.True(t, errors.Is(err, dbsqlerr.ExecutionError)) assert.Equal(t, 1, executeStatementCount) }) diff --git a/driverctx/ctx.go b/driverctx/ctx.go index 77a8286..4ccbbbe 100644 --- a/driverctx/ctx.go +++ b/driverctx/ctx.go @@ -104,3 +104,17 @@ func NewContextWithConnIdCallback(ctx context.Context, callback IdCallbackFunc) func NewContextWithStagingInfo(ctx context.Context, stagingAllowedLocalPath []string) context.Context { return context.WithValue(ctx, StagingAllowedLocalPathKey, stagingAllowedLocalPath) } + +func NewContextFromBackground(ctx context.Context) context.Context { + connId := ConnIdFromContext(ctx) + corrId := CorrelationIdFromContext(ctx) + queryId := QueryIdFromContext(ctx) + stagingPaths := StagingPathsFromContext(ctx) + + newCtx := NewContextWithConnId(context.Background(), connId) + newCtx = NewContextWithCorrelationId(newCtx, corrId) + newCtx = NewContextWithQueryId(newCtx, queryId) + newCtx = NewContextWithStagingInfo(newCtx, stagingPaths) + + return newCtx +} diff --git a/internal/client/client.go b/internal/client/client.go index cb6c961..fda1053 100644 --- a/internal/client/client.go +++ b/internal/client/client.go @@ -12,6 +12,7 @@ import ( "net/http/httptrace" "net/url" "os" + "reflect" "regexp" "strconv" "strings" @@ -79,9 +80,14 @@ var clientMethodRequestErrorMsgs map[clientMethod]string = map[clientMethod]stri // OpenSession is a wrapper around the thrift operation OpenSession // If RecordResults is true, the results will be marshalled to JSON format and written to OpenSession.json func (tsc *ThriftServiceClient) OpenSession(ctx context.Context, req *cli_service.TOpenSessionReq) (*cli_service.TOpenSessionResp, error) { - ctx = context.WithValue(ctx, ClientMethod, clientMethodOpenSession) + ctx = startClientMethod(ctx, clientMethodOpenSession) + var log *logger.DBSQLLogger msg, start := logger.Track("OpenSession") + resp, err := tsc.TCLIServiceClient.OpenSession(ctx, req) + log, ctx = LoggerAndContext(ctx, resp) + logDisplayMessage(resp, log) + defer log.Duration(msg, start) if err != nil { err = handleClientMethodError(ctx, err) return resp, err @@ -89,19 +95,19 @@ func (tsc *ThriftServiceClient) OpenSession(ctx context.Context, req *cli_servic recordResult(ctx, resp) - log := logger.WithContext(SprintGuid(resp.SessionHandle.SessionId.GUID), driverctx.CorrelationIdFromContext(ctx), "") - defer log.Duration(msg, start) - return resp, CheckStatus(resp) } // CloseSession is a wrapper around the thrift operation CloseSession // If RecordResults is true, the results will be marshalled to JSON format and written to CloseSession.json func (tsc *ThriftServiceClient) CloseSession(ctx context.Context, req *cli_service.TCloseSessionReq) (*cli_service.TCloseSessionResp, error) { - ctx = context.WithValue(ctx, ClientMethod, clientMethodCloseSession) - log := logger.WithContext(driverctx.ConnIdFromContext(ctx), driverctx.CorrelationIdFromContext(ctx), "") + ctx = startClientMethod(ctx, clientMethodCloseSession) + var log *logger.DBSQLLogger + log, ctx = LoggerAndContext(ctx, req) defer log.Duration(logger.Track("CloseSession")) + resp, err := tsc.TCLIServiceClient.CloseSession(ctx, req) + logDisplayMessage(resp, log) if err != nil { err = handleClientMethodError(ctx, err) return resp, err @@ -115,10 +121,13 @@ func (tsc *ThriftServiceClient) CloseSession(ctx context.Context, req *cli_servi // FetchResults is a wrapper around the thrift operation FetchResults // If RecordResults is true, the results will be marshalled to JSON format and written to FetchResults.json func (tsc *ThriftServiceClient) FetchResults(ctx context.Context, req *cli_service.TFetchResultsReq) (*cli_service.TFetchResultsResp, error) { - ctx = context.WithValue(ctx, ClientMethod, clientMethodFetchResults) - log := logger.WithContext(driverctx.ConnIdFromContext(ctx), driverctx.CorrelationIdFromContext(ctx), SprintGuid(req.OperationHandle.OperationId.GUID)) + ctx = startClientMethod(ctx, clientMethodFetchResults) + var log *logger.DBSQLLogger + log, ctx = LoggerAndContext(ctx, req) defer log.Duration(logger.Track("FetchResults")) + resp, err := tsc.TCLIServiceClient.FetchResults(ctx, req) + logDisplayMessage(resp, log) if err != nil { err = handleClientMethodError(ctx, err) return resp, err @@ -132,10 +141,13 @@ func (tsc *ThriftServiceClient) FetchResults(ctx context.Context, req *cli_servi // GetResultSetMetadata is a wrapper around the thrift operation GetResultSetMetadata // If RecordResults is true, the results will be marshalled to JSON format and written to GetResultSetMetadata.json func (tsc *ThriftServiceClient) GetResultSetMetadata(ctx context.Context, req *cli_service.TGetResultSetMetadataReq) (*cli_service.TGetResultSetMetadataResp, error) { - ctx = context.WithValue(ctx, ClientMethod, clientMethodGetResultSetMetadata) - log := logger.WithContext(driverctx.ConnIdFromContext(ctx), driverctx.CorrelationIdFromContext(ctx), SprintGuid(req.OperationHandle.OperationId.GUID)) + ctx = startClientMethod(ctx, clientMethodGetResultSetMetadata) + var log *logger.DBSQLLogger + log, ctx = LoggerAndContext(ctx, req) defer log.Duration(logger.Track("GetResultSetMetadata")) + resp, err := tsc.TCLIServiceClient.GetResultSetMetadata(ctx, req) + logDisplayMessage(resp, log) if err != nil { err = handleClientMethodError(ctx, err) return resp, err @@ -149,32 +161,36 @@ func (tsc *ThriftServiceClient) GetResultSetMetadata(ctx context.Context, req *c // ExecuteStatement is a wrapper around the thrift operation ExecuteStatement // If RecordResults is true, the results will be marshalled to JSON format and written to ExecuteStatement.json func (tsc *ThriftServiceClient) ExecuteStatement(ctx context.Context, req *cli_service.TExecuteStatementReq) (*cli_service.TExecuteStatementResp, error) { - ctx = context.WithValue(ctx, ClientMethod, clientMethodExecuteStatement) - msg, start := logger.Track("ExecuteStatement") + ctx = startClientMethod(ctx, clientMethodExecuteStatement) + var log *logger.DBSQLLogger + log, ctx = LoggerAndContext(ctx, req) + msg, start := log.Track("ExecuteStatement") // We use context.Background to fix a problem where on context done the query would not be cancelled. resp, err := tsc.TCLIServiceClient.ExecuteStatement(context.Background(), req) + log, ctx = LoggerAndContext(ctx, resp) + logDisplayMessage(resp, log) + logExecStatementState(resp, log) + + defer log.Duration(msg, start) if err != nil { err = handleClientMethodError(ctx, err) return resp, err } - recordResult(ctx, resp) - - if resp != nil && resp.OperationHandle != nil { - log := logger.WithContext(driverctx.ConnIdFromContext(ctx), driverctx.CorrelationIdFromContext(ctx), SprintGuid(resp.OperationHandle.OperationId.GUID)) - defer log.Duration(msg, start) - } return resp, CheckStatus(resp) } // GetOperationStatus is a wrapper around the thrift operation GetOperationStatus // If RecordResults is true, the results will be marshalled to JSON format and written to GetOperationStatus.json func (tsc *ThriftServiceClient) GetOperationStatus(ctx context.Context, req *cli_service.TGetOperationStatusReq) (*cli_service.TGetOperationStatusResp, error) { - ctx = context.WithValue(ctx, ClientMethod, clientMethodGetOperationStatus) - log := logger.WithContext(driverctx.ConnIdFromContext(ctx), driverctx.CorrelationIdFromContext(ctx), SprintGuid(req.OperationHandle.OperationId.GUID)) + ctx = startClientMethod(ctx, clientMethodGetOperationStatus) + var log *logger.DBSQLLogger + log, ctx = LoggerAndContext(ctx, req) defer log.Duration(logger.Track("GetOperationStatus")) + resp, err := tsc.TCLIServiceClient.GetOperationStatus(ctx, req) + logDisplayMessage(resp, log) if err != nil { err = handleClientMethodError(driverctx.NewContextWithQueryId(ctx, SprintGuid(req.OperationHandle.OperationId.GUID)), err) return resp, err @@ -188,10 +204,13 @@ func (tsc *ThriftServiceClient) GetOperationStatus(ctx context.Context, req *cli // CloseOperation is a wrapper around the thrift operation CloseOperation // If RecordResults is true, the results will be marshalled to JSON format and written to CloseOperation.json func (tsc *ThriftServiceClient) CloseOperation(ctx context.Context, req *cli_service.TCloseOperationReq) (*cli_service.TCloseOperationResp, error) { - ctx = context.WithValue(ctx, ClientMethod, clientMethodCloseOperation) - log := logger.WithContext(driverctx.ConnIdFromContext(ctx), driverctx.CorrelationIdFromContext(ctx), SprintGuid(req.OperationHandle.OperationId.GUID)) + ctx = startClientMethod(ctx, clientMethodCloseOperation) + var log *logger.DBSQLLogger + log, ctx = LoggerAndContext(ctx, req) defer log.Duration(logger.Track("CloseOperation")) + resp, err := tsc.TCLIServiceClient.CloseOperation(ctx, req) + logDisplayMessage(resp, log) if err != nil { err = handleClientMethodError(ctx, err) return resp, err @@ -205,10 +224,13 @@ func (tsc *ThriftServiceClient) CloseOperation(ctx context.Context, req *cli_ser // CancelOperation is a wrapper around the thrift operation CancelOperation // If RecordResults is true, the results will be marshalled to JSON format and written to CancelOperation.json func (tsc *ThriftServiceClient) CancelOperation(ctx context.Context, req *cli_service.TCancelOperationReq) (*cli_service.TCancelOperationResp, error) { - ctx = context.WithValue(ctx, ClientMethod, clientMethodCancelOperation) - log := logger.WithContext(driverctx.ConnIdFromContext(ctx), driverctx.CorrelationIdFromContext(ctx), SprintGuid(req.OperationHandle.OperationId.GUID)) + ctx = startClientMethod(ctx, clientMethodCancelOperation) + var log *logger.DBSQLLogger + log, ctx = LoggerAndContext(ctx, req) defer log.Duration(logger.Track("CancelOperation")) + resp, err := tsc.TCLIServiceClient.CancelOperation(ctx, req) + logDisplayMessage(resp, log) if err != nil { err = handleClientMethodError(ctx, err) return resp, err @@ -286,6 +308,13 @@ func InitThriftClient(cfg *config.Config, httpclient *http.Client) (*ThriftServi return tsClient, nil } +func startClientMethod(ctx context.Context, method clientMethod) context.Context { + ctx = context.WithValue(ctx, ClientMethod, method) + log, _ := LoggerAndContext(ctx, nil) + log.Debug().Msgf("client.%s", method.String()) + return ctx +} + // handler function for errors returned by the thrift client methods func handleClientMethodError(ctx context.Context, err error) dbsqlerr.DBRequestError { if err == nil { @@ -303,7 +332,13 @@ func handleClientMethodError(ctx context.Context, err error) dbsqlerr.DBRequestE method := getClientMethod(ctx) msg := clientMethodRequestErrorMsgs[method] - return dbsqlerrint.NewRequestError(ctx, msg, err) + dbErr := dbsqlerrint.NewRequestError(ctx, msg, err) + + log, _ := LoggerAndContext(ctx, nil) + + log.Err(err).Msg("") + + return dbErr } // Extract a clientMethod value from the given Context. @@ -355,6 +390,95 @@ func SprintGuid(bts []byte) string { return fmt.Sprintf("%x", bts) } +// Create an updated context and a logger that includes query and connection id +func LoggerAndContext(ctx context.Context, c any) (*logger.DBSQLLogger, context.Context) { + connId := driverctx.ConnIdFromContext(ctx) + corrId := driverctx.CorrelationIdFromContext(ctx) + queryId := driverctx.QueryIdFromContext(ctx) + if connId == "" { + connId = guidFromHasSessionHandle(c) + ctx = driverctx.NewContextWithConnId(ctx, connId) + } + if queryId == "" { + queryId = guidFromHasOpHandle(c) + ctx = driverctx.NewContextWithQueryId(ctx, queryId) + } + log := logger.WithContext(connId, corrId, queryId) + + return log, ctx +} + +type hasOpHandle interface { + GetOperationHandle() *cli_service.TOperationHandle +} +type hasSessionHandle interface { + GetSessionHandle() *cli_service.TSessionHandle +} + +func guidFromHasOpHandle(c any) (guid string) { + if c == nil || reflect.ValueOf(c).IsNil() { + return + } + if ho, ok := c.(hasOpHandle); ok { + opHandle := ho.GetOperationHandle() + if opHandle != nil && opHandle.OperationId != nil && opHandle.OperationId.GUID != nil { + guid = SprintGuid(opHandle.OperationId.GUID) + } + } + return +} + +func guidFromHasSessionHandle(c any) (guid string) { + if c == nil || reflect.ValueOf(c).IsNil() { + return + } + if ho, ok := c.(hasSessionHandle); ok { + sessionHandle := ho.GetSessionHandle() + if sessionHandle != nil && sessionHandle.SessionId != nil && sessionHandle.SessionId.GUID != nil { + guid = SprintGuid(sessionHandle.SessionId.GUID) + } + } + return +} + +func logExecStatementState(resp *cli_service.TExecuteStatementResp, log *logger.DBSQLLogger) { + if resp != nil { + if resp.DirectResults != nil { + state := resp.DirectResults.GetOperationStatus().GetOperationState() + log.Debug().Msgf("execute statement state: %s", state) + status := resp.DirectResults.GetOperationStatus().GetStatus().StatusCode + log.Debug().Msgf("execute statement status: %s", status) + logDisplayMessage(resp.DirectResults, log) + } else { + status := resp.GetStatus().StatusCode + log.Debug().Msgf("execute statement status: %s", status) + } + } +} + +type hasGetStatus interface{ GetStatus() *cli_service.TStatus } +type hasGetDisplayMessage interface{ GetDisplayMessage() string } +type hasGetOperationStatus interface { + GetOperationStatus() *cli_service.TGetOperationStatusResp +} + +func logDisplayMessage(c any, log *logger.DBSQLLogger) { + if c == nil || reflect.ValueOf(c).IsNil() { + return + } + + if hd, ok := c.(hasGetDisplayMessage); ok { + dm := hd.GetDisplayMessage() + if dm != "" { + log.Debug().Msg(dm) + } + } else if gs, ok := c.(hasGetStatus); ok { + logDisplayMessage(gs.GetStatus(), log) + } else if gos, ok := c.(hasGetOperationStatus); ok { + logDisplayMessage(gos.GetOperationStatus(), log) + } +} + var retryableStatusCodes = map[int]any{http.StatusTooManyRequests: struct{}{}, http.StatusServiceUnavailable: struct{}{}} func isRetryableServerResponse(resp *http.Response) bool { diff --git a/internal/fetcher/fetcher.go b/internal/fetcher/fetcher.go index 2f53754..8430ff0 100644 --- a/internal/fetcher/fetcher.go +++ b/internal/fetcher/fetcher.go @@ -46,13 +46,13 @@ func (f *concurrentFetcher[I, O]) Start() (<-chan O, context.CancelFunc, error) // increment wait group wg.Add(1) - f.logger().Debug().Msgf("concurrent fetcher starting worker %d", i) + f.logger().Trace().Msgf("concurrent fetcher starting worker %d", i) go func(x int) { // when work function remove one from the wait group defer wg.Done() // do the actual work work(f, x) - f.logger().Debug().Msgf("concurrent fetcher worker %d done", x) + f.logger().Trace().Msgf("concurrent fetcher worker %d done", x) }(i) } @@ -62,7 +62,7 @@ func (f *concurrentFetcher[I, O]) Start() (<-chan O, context.CancelFunc, error) // be stuck waiting on the output channel. go func() { wg.Wait() - f.logger().Debug().Msg("concurrent fetcher closing output channel") + f.logger().Trace().Msg("concurrent fetcher closing output channel") close(f.outChan) }() @@ -70,9 +70,9 @@ func (f *concurrentFetcher[I, O]) Start() (<-chan O, context.CancelFunc, error) // cancel fetching. var cancelOnce sync.Once = sync.Once{} f.cancelFunc = func() { - f.logger().Debug().Msg("concurrent fetcher cancel func") + f.logger().Trace().Msg("concurrent fetcher cancel func") cancelOnce.Do(func() { - f.logger().Debug().Msg("concurrent fetcher closing cancel channel") + f.logger().Trace().Msg("concurrent fetcher closing cancel channel") close(f.cancelChan) }) } @@ -142,19 +142,19 @@ func work[I FetchableItems[O], O any](f *concurrentFetcher[I, O], workerIndex in case input, ok := <-f.inputChan: if ok { - f.logger().Debug().Msgf("concurrent fetcher worker %d loading item", workerIndex) + f.logger().Trace().Msgf("concurrent fetcher worker %d loading item", workerIndex) result, err := input.Fetch(f.ctx) if err != nil { - f.logger().Debug().Msgf("concurrent fetcher worker %d received error", workerIndex) + f.logger().Trace().Msgf("concurrent fetcher worker %d received error", workerIndex) f.setErr(err) f.cancelFunc() return } else { - f.logger().Debug().Msgf("concurrent fetcher worker %d item loaded", workerIndex) + f.logger().Trace().Msgf("concurrent fetcher worker %d item loaded", workerIndex) f.outChan <- result } } else { - f.logger().Debug().Msgf("concurrent fetcher ending %d", workerIndex) + f.logger().Trace().Msgf("concurrent fetcher ending %d", workerIndex) return } diff --git a/internal/sentinel/sentinel.go b/internal/sentinel/sentinel.go index 6886f1e..529f229 100644 --- a/internal/sentinel/sentinel.go +++ b/internal/sentinel/sentinel.go @@ -5,7 +5,7 @@ import ( "fmt" "time" - "github.com/databricks/databricks-sql-go/logger" + "github.com/databricks/databricks-sql-go/internal/client" "github.com/pkg/errors" ) @@ -84,6 +84,8 @@ func (s Sentinel) Watch(ctx context.Context, interval, timeout time.Duration) (W } } + log, _ := client.LoggerAndContext(ctx, nil) + // If the watch times out or is cancelled this function // will stop the interval timer and call the cancel function // if necessary. @@ -93,9 +95,9 @@ func (s Sentinel) Watch(ctx context.Context, interval, timeout time.Duration) (W s.onCancelFnCalled = true _, err := s.OnCancelFn() if err != nil { - logger.Err(err).Msg("databricks: cancel failed") + log.Err(err).Msg("databricks: cancel failed") } else { - logger.Debug().Msgf("databricks: cancel success") + log.Debug().Msgf("databricks: cancel success") } } } @@ -122,11 +124,12 @@ func (s Sentinel) Watch(ctx context.Context, interval, timeout time.Duration) (W case res := <-resCh: return WatchSuccess, res, nil case <-ctx.Done(): + log.Debug().Msgf("sentinel <-ctx.Done: %s", ctx.Err().Error()) timeoutOrCancel() return WatchCanceled, nil, ctx.Err() case <-timeoutTimerCh: msg := fmt.Sprintf("wait timed out after %s", timeout.String()) - logger.Info().Msg(msg) + log.Info().Msg(msg) timeoutOrCancel() err := errors.New(msg) return WatchTimeout, nil, err