diff --git a/auth/oauth/m2m/m2m.go b/auth/oauth/m2m/m2m.go index aa72eb9..7b68404 100644 --- a/auth/oauth/m2m/m2m.go +++ b/auth/oauth/m2m/m2m.go @@ -17,7 +17,11 @@ import ( ) func NewAuthenticator(clientID, clientSecret, hostName string) auth.Authenticator { - scopes := oauth.GetScopes(hostName, []string{}) + return NewAuthenticatorWithScopes(clientID, clientSecret, hostName, []string{}) +} + +func NewAuthenticatorWithScopes(clientID, clientSecret, hostName string, scopes []string) auth.Authenticator { + scopes = GetScopes(hostName, scopes) return &authClient{ clientID: clientID, clientSecret: clientSecret, @@ -89,3 +93,11 @@ func GetConfig(ctx context.Context, issuerURL, clientID, clientSecret string, sc return config, nil } + +func GetScopes(hostName string, scopes []string) []string { + if !oauth.HasScope(scopes, "all-apis") { + scopes = append(scopes, "all-apis") + } + + return scopes +} diff --git a/auth/oauth/oauth.go b/auth/oauth/oauth.go index 2e94dba..2ffd28a 100644 --- a/auth/oauth/oauth.go +++ b/auth/oauth/oauth.go @@ -45,7 +45,7 @@ func GetEndpoint(ctx context.Context, hostName string) (oauth2.Endpoint, error) func GetScopes(hostName string, scopes []string) []string { for _, s := range []string{oidc.ScopeOfflineAccess} { - if !hasScope(scopes, s) { + if !HasScope(scopes, s) { scopes = append(scopes, s) } } @@ -53,11 +53,11 @@ func GetScopes(hostName string, scopes []string) []string { cloudType := InferCloudFromHost(hostName) if cloudType == Azure { userImpersonationScope := fmt.Sprintf("%s/user_impersonation", azureTenantId) - if !hasScope(scopes, userImpersonationScope) { + if !HasScope(scopes, userImpersonationScope) { scopes = append(scopes, userImpersonationScope) } } else { - if !hasScope(scopes, "sql") { + if !HasScope(scopes, "sql") { scopes = append(scopes, "sql") } } @@ -65,7 +65,7 @@ func GetScopes(hostName string, scopes []string) []string { return scopes } -func hasScope(scopes []string, scope string) bool { +func HasScope(scopes []string, scope string) bool { for _, s := range scopes { if s == scope { return true diff --git a/auth/oauth/u2m/authenticator.go b/auth/oauth/u2m/authenticator.go index 29a973c..5a5aabb 100644 --- a/auth/oauth/u2m/authenticator.go +++ b/auth/oauth/u2m/authenticator.go @@ -25,11 +25,11 @@ import ( ) const ( - azureClientId = "96eecda7-19ea-49cc-abb5-240097d554f5" - azureRedirctURL = "localhost:8030" + azureClientId = "96eecda7-19ea-49cc-abb5-240097d554f5" + azureRedirectURL = "localhost:8030" - awsClientId = "databricks-sql-connector" - awsRedirctURL = "localhost:8030" + awsClientId = "databricks-sql-connector" + awsRedirectURL = "localhost:8030" ) func NewAuthenticator(hostName string, timeout time.Duration) (auth.Authenticator, error) { @@ -39,10 +39,10 @@ func NewAuthenticator(hostName string, timeout time.Duration) (auth.Authenticato var clientID, redirectURL string if cloud == oauth.AWS { clientID = awsClientId - redirectURL = awsRedirctURL + redirectURL = awsRedirectURL } else if cloud == oauth.Azure { clientID = azureClientId - redirectURL = azureRedirctURL + redirectURL = azureRedirectURL } else { return nil, errors.New("unhandled cloud type: " + cloud.String()) } diff --git a/driverctx/ctx.go b/driverctx/ctx.go index 4ccbbbe..f8f4674 100644 --- a/driverctx/ctx.go +++ b/driverctx/ctx.go @@ -38,6 +38,7 @@ func CorrelationIdFromContext(ctx context.Context) string { } // NewContextWithConnId creates a new context with connectionId value. +// The connection ID will be displayed in log messages and other dianostic information. func NewContextWithConnId(ctx context.Context, connId string) context.Context { if callback, ok := ctx.Value(ConnIdCallbackKey).(IdCallbackFunc); ok { callback(connId) @@ -59,6 +60,7 @@ func ConnIdFromContext(ctx context.Context) string { } // NewContextWithQueryId creates a new context with queryId value. +// The query id will be displayed in log messages and other diagnostic information. func NewContextWithQueryId(ctx context.Context, queryId string) context.Context { if callback, ok := ctx.Value(QueryIdCallbackKey).(IdCallbackFunc); ok { callback(queryId)