From 948841de86c77aa2c048a097cbf92df7da8946eb Mon Sep 17 00:00:00 2001 From: "Rose, William" Date: Sun, 2 Feb 2020 22:23:11 -0800 Subject: [PATCH] Added support for Azure Active Directory in DSN and via Connector. --- .gitignore | 5 + README.md | 82 ++- accesstokenconnector.go | 28 +- accesstokenconnector_test.go | 10 +- appveyor.yml | 2 + azuread/fedauth_adal.go | 136 +++++ azuread/fedauth_adal_test.go | 203 +++++++ conn_str.go | 110 +++- doc/how-to-test-azure-ad-authentication.md | 176 ++++++ examples/azuread/ad-service-principal-dsn.jq | 1 + examples/azuread/ad-system-assigned-id-dsn.jq | 1 + examples/azuread/ad-user-assigned-id-dsn.jq | 1 + examples/azuread/ad-user-password-dsn.jq | 1 + examples/azuread/azuread.go | 144 +++++ examples/azuread/environment-settings.jq | 20 + examples/azuread/go.mod | 10 + examples/azuread/testing.tf | 490 +++++++++++++++++ examples/simple/simple.go | 41 +- examples/tvp/tvp.go | 2 + fedauth.go | 103 ++++ go.mod | 2 + go.sum | 26 + log_conn.go | 80 +++ mssql.go | 26 + tds.go | 511 +++++++++++------- tds_test.go | 28 +- token.go | 133 ++++- token_string.go | 66 ++- tvp_go19_db_test.go | 6 +- 29 files changed, 2116 insertions(+), 328 deletions(-) create mode 100644 .gitignore create mode 100644 azuread/fedauth_adal.go create mode 100644 azuread/fedauth_adal_test.go create mode 100644 doc/how-to-test-azure-ad-authentication.md create mode 100644 examples/azuread/ad-service-principal-dsn.jq create mode 100644 examples/azuread/ad-system-assigned-id-dsn.jq create mode 100644 examples/azuread/ad-user-assigned-id-dsn.jq create mode 100644 examples/azuread/ad-user-password-dsn.jq create mode 100644 examples/azuread/azuread.go create mode 100644 examples/azuread/environment-settings.jq create mode 100644 examples/azuread/go.mod create mode 100644 examples/azuread/testing.tf create mode 100644 fedauth.go create mode 100644 log_conn.go diff --git a/.gitignore b/.gitignore new file mode 100644 index 00000000..f110691d --- /dev/null +++ b/.gitignore @@ -0,0 +1,5 @@ +.terraform +*.tfstate* +*.log +*.swp +*~ diff --git a/README.md b/README.md index 46e4eeab..414d7e6e 100644 --- a/README.md +++ b/README.md @@ -54,10 +54,15 @@ Other supported formats are listed below. * true - Server certificate is not checked. Default is true if encrypt is not specified. If trust server certificate is true, driver accepts any certificate presented by the server and any host name in that certificate. In this mode, TLS is susceptible to man-in-the-middle attacks. This should be used only for testing. * `certificate` - The file that contains the public key certificate of the CA that signed the SQL Server certificate. The specified certificate overrides the go platform specific CA certificates. * `hostNameInCertificate` - Specifies the Common Name (CN) in the server certificate. Default value is the server host. -* `ServerSPN` - The kerberos SPN (Service Principal Name) for the server. Default is MSSQLSvc/host:port. +* `ServerSPN` - The Kerberos SPN (Service Principal Name) for the server. Default is MSSQLSvc/host:port. * `Workstation ID` - The workstation name (default is the host name) * `ApplicationIntent` - Can be given the value `ReadOnly` to initiate a read-only connection to an Availability Group listener. The `database` must be specified when connecting with `Application Intent` set to `ReadOnly`. - +* `FedAuth` - The federated authentication scheme to use. See below for additional setup requirements. + * `ActiveDirectoryApplication` - authenticates using an Azure Active Directory application client ID and client secret or certificate. Set the `user` to `client-ID@tenant-ID` and the `password` to the client secret. If using client certificates, provide the path to the PKCS#12 file containing the certificate and RSA private key in the `ClientCertPath` parameter, and set the `password` to the value needed to open the PKCS#12 file. + * `ActiveDirectoryMSI` - authenticates using the managed service identity (MSI) attached to the VM, or a specific user-assigned identity if a client ID is specified in the `user` field. + * `ActiveDirectoryPassword` - authenticates an Azure Active Directory user account in the form `user@domain.com` with a password. This method is not recommended for general use and does not support multi-factor authentication for accounts. + * `ActiveDirectoryIntegrated` - configures the connection to request Active Directory Integrated authentication. This method is not fully supported: you must also implement a token provider to obtain the token for the currently logged-in user and supply it in the `ActiveDirectoryTokenProvider` field in the `Connector` as described below. + ### The connection string can be specified in one of three formats: @@ -106,25 +111,68 @@ Other supported formats are listed below. * `odbc:server=localhost;user id=sa;password={foo{bar}` // Literal `{`, password is "foo{bar" * `odbc:server=localhost;user id=sa;password={foo}}bar}` // Escaped `} with `}}`, password is "foo}bar" -### Azure Active Directory authentication - preview +### Azure Active Directory authentication + +Azure Active Directory authentication uses temporary authentication tokens to authenticate. +To have the driver obtain these tokens using the +[Active Directory Authentication Library for Go](https://github.com/Azure/go-autorest/tree/master/autorest/adal), +import the Azure AD module in addition to the normal driver module, and configure the +connection string with a `FedAuth` option and supporting information as described above. + +```golang +import ( + "database/sql" + "net/url" -The configuration of functionality might change in the future. + // Import the Azure AD driver module (also imports the regular driver package) + _ "github.com/denisenkom/go-mssqldb/azuread" +) + +func ConnectWithMSI() (*sql.DB, error) { + return sql.Open("sqlserver", "sqlserver://azuresql.database.windows.net?database=yourdb&fedauth=ActiveDirectoryMSI") +} +``` + +As an alternative, you can select the federated authentication library and Active Directory +using the connection string parameters, but then implement your own routine for obtaining +tokens. + +```golang +import ( + "context" + "database/sql" + "net/url" + + // Import the driver + "github.com/denisenkom/go-mssqldb" +) + +func ConnectWithADToken() (*sql.DB, error) { + conn, err := mssql.NewConnector("sqlserver://azuresql.database.windows.net?database=yourdb&fedauth=ActiveDirectoryApplication") + if err != nil { + // handle errors in DSN + } + + conn.SecurityTokenProvider = func(ctx context.Context) (string, error) { + return "the token", nil + } + + return sql.OpenDB(conn), nil +} + +func ConnectWithADIntegrated() (*sql.DB, error) { + conn, err := mssql.NewConnector("sqlserver://azuresq;.database.windows.net?database=yourdb&fedauth=ActiveDirectoryIntegrated") + if err != nil { + // handle errors in DSN + } + + c.ActiveDirectoryTokenProvider = func(ctx context.Context, serverSPN, stsURL string) (string, error) { + return "the token", nil + } -Azure Active Directory (AAD) access tokens are relatively short lived and need to be -valid when a new connection is made. Authentication is supported using a callback func that -provides a fresh and valid token using a connector: -``` golang -conn, err := mssql.NewAccessTokenConnector( - "Server=test.database.windows.net;Database=testdb", - tokenProvider) -if err != nil { - // handle errors in DSN + return sql.OpenDB(conn), nil } -db := sql.OpenDB(conn) ``` -Where `tokenProvider` is a function that returns a fresh access token or an error. None of these statements -actually trigger the retrieval of a token, this happens when the first statment is issued and a connection -is created. ## Executing Stored Procedures diff --git a/accesstokenconnector.go b/accesstokenconnector.go index 8dbe5099..fa6bd381 100644 --- a/accesstokenconnector.go +++ b/accesstokenconnector.go @@ -6,19 +6,8 @@ import ( "context" "database/sql/driver" "errors" - "fmt" ) -var _ driver.Connector = &accessTokenConnector{} - -// accessTokenConnector wraps Connector and injects a -// fresh access token when connecting to the database -type accessTokenConnector struct { - Connector - - accessTokenProvider func() (string, error) -} - // NewAccessTokenConnector creates a new connector from a DSN and a token provider. // The token provider func will be called when a new connection is requested and should return a valid access token. // The returned connector may be used with sql.OpenDB. @@ -32,20 +21,9 @@ func NewAccessTokenConnector(dsn string, tokenProvider func() (string, error)) ( return nil, err } - c := &accessTokenConnector{ - Connector: *conn, - accessTokenProvider: tokenProvider, - } - return c, nil -} - -// Connect returns a new database connection -func (c *accessTokenConnector) Connect(ctx context.Context) (driver.Conn, error) { - var err error - c.Connector.params.fedAuthAccessToken, err = c.accessTokenProvider() - if err != nil { - return nil, fmt.Errorf("mssql: error retrieving access token: %+v", err) + conn.SecurityTokenProvider = func(ctx context.Context) (string, error) { + return tokenProvider() } - return c.Connector.Connect(ctx) + return conn, nil } diff --git a/accesstokenconnector_test.go b/accesstokenconnector_test.go index 826dedba..5865aa37 100644 --- a/accesstokenconnector_test.go +++ b/accesstokenconnector_test.go @@ -30,21 +30,21 @@ func TestNewAccessTokenConnector(t *testing.T) { dsn: dsn, tokenProvider: tp}, want: func(c driver.Connector) error { - tc, ok := c.(*accessTokenConnector) + tc, ok := c.(*Connector) if !ok { - return fmt.Errorf("Expected driver to be of type *accessTokenConnector, but got %T", c) + return fmt.Errorf("Expected driver to be of type *Connector, but got %T", c) } - p := tc.Connector.params + p := tc.params if p.database != "db" { return fmt.Errorf("expected params.database=db, but got %v", p.database) } if p.host != "server.database.windows.net" { return fmt.Errorf("expected params.host=server.database.windows.net, but got %v", p.host) } - if tc.accessTokenProvider == nil { + if tc.SecurityTokenProvider == nil { return fmt.Errorf("Expected tokenProvider to not be nil") } - t, err := tc.accessTokenProvider() + t, err := tc.SecurityTokenProvider(context.TODO()) if t != "token" || err != nil { return fmt.Errorf("Unexpected results from tokenProvider: %v, %v", t, err) } diff --git a/appveyor.yml b/appveyor.yml index c4d2bb06..51de3c26 100644 --- a/appveyor.yml +++ b/appveyor.yml @@ -30,6 +30,8 @@ install: - go version - go env - go get -u github.com/golang-sql/civil + - go get -u golang.org/x/crypto/pkcs12 + - go get -u github.com/Azure/go-autorest/autorest/adal build_script: - go build diff --git a/azuread/fedauth_adal.go b/azuread/fedauth_adal.go new file mode 100644 index 00000000..b23eb8e9 --- /dev/null +++ b/azuread/fedauth_adal.go @@ -0,0 +1,136 @@ +package azuread + +import ( + "context" + "crypto/rsa" + "crypto/x509" + "fmt" + + "github.com/Azure/go-autorest/autorest/adal" + mssql "github.com/denisenkom/go-mssqldb" +) + +type azureFedAuthConfigurer struct{} + +func init() { + mssql.SetFederatedAuthenticationConfigurer(&azureFedAuthConfigurer{}) +} + +// When the security token library is used, the token is obtained without input +// from the server, so the AD endpoint and Azure SQL resource URI are provided +// from the constants below. +const ( + // activeDirectoryEndpoint is the security token service URL to use when + // the server does not provide the URL. + activeDirectoryEndpoint = "https://login.microsoftonline.com/" + + // azureSQLResource is the AD resource to use when the server does not + // provide the resource. + azureSQLResource = "https://database.windows.net/" + + // driverClientID is the AD client ID to use when performing a username + // and password login. + driverClientID = "7f98cb04-cd1e-40df-9140-3bf7e2cea4db" +) + +func retrieveToken(ctx context.Context, token *adal.ServicePrincipalToken) (string, error) { + err := token.RefreshWithContext(ctx) + if err != nil { + err = fmt.Errorf("Failed to refresh token: %v", err) + return "", err + } + + return token.Token().AccessToken, nil +} + +func (az *azureFedAuthConfigurer) SecurityTokenProviderFromCertificate(clientID, tenantID string, certificate *x509.Certificate, rsaPrivateKey *rsa.PrivateKey) mssql.SecurityTokenProvider { + return func(ctx context.Context) (string, error) { + // The activeDirectoryEndpoint URL is used as a base against which the + // tenant ID is resolved. + oauthConfig, err := adal.NewOAuthConfig(activeDirectoryEndpoint, tenantID) + if err != nil { + err = fmt.Errorf("Failed to obtain OAuth configuration for endpoint %s and tenant %s: %v", + activeDirectoryEndpoint, tenantID, err) + return "", err + } + + token, err := adal.NewServicePrincipalTokenFromCertificate(*oauthConfig, clientID, certificate, rsaPrivateKey, azureSQLResource) + if err != nil { + err = fmt.Errorf("Failed to obtain service principal token for client id %s in tenant %s: %v", clientID, tenantID, err) + return "", err + } + + return retrieveToken(ctx, token) + } +} + +func (az *azureFedAuthConfigurer) SecurityTokenProviderFromSecret(clientID, tenantID, clientSecret string) mssql.SecurityTokenProvider { + return func(ctx context.Context) (string, error) { + // The activeDirectoryEndpoint URL is used as a base against which the + // tenant ID is resolved. + oauthConfig, err := adal.NewOAuthConfig(activeDirectoryEndpoint, tenantID) + if err != nil { + err = fmt.Errorf("Failed to obtain OAuth configuration for endpoint %s and tenant %s: %v", + activeDirectoryEndpoint, tenantID, err) + return "", err + } + + token, err := adal.NewServicePrincipalToken(*oauthConfig, clientID, clientSecret, azureSQLResource) + + if err != nil { + err = fmt.Errorf("Failed to obtain service principal token for client id %s in tenant %s: %v", clientID, tenantID, err) + return "", err + } + + return retrieveToken(ctx, token) + } +} + +func (az *azureFedAuthConfigurer) ActiveDirectoryTokenProviderFromPassword(user, password string) mssql.ActiveDirectoryTokenProvider { + return func(ctx context.Context, serverSPN, stsURL string) (string, error) { + // The activeDirectoryEndpoint URL is used as a base against which the + // STS URL is resolved. However, the STS URL is normally absolute and + // the activeDirectoryEndpoint URL is completely ignored. + oauthConfig, err := adal.NewOAuthConfig(activeDirectoryEndpoint, stsURL) + if err != nil { + err = fmt.Errorf("Failed to obtain OAuth configuration for endpoint %s and tenant %s: %v", + activeDirectoryEndpoint, stsURL, err) + return "", err + } + + token, err := adal.NewServicePrincipalTokenFromUsernamePassword(*oauthConfig, driverClientID, user, password, serverSPN) + + if err != nil { + err = fmt.Errorf("Failed to obtain token for user %s for resource %s from service %s: %v", user, serverSPN, stsURL, err) + return "", err + } + + return retrieveToken(ctx, token) + } +} + +func (az *azureFedAuthConfigurer) ActiveDirectoryTokenProviderFromIdentity(clientID string) mssql.ActiveDirectoryTokenProvider { + return func(ctx context.Context, serverSPN, stsURL string) (string, error) { + msiEndpoint, err := adal.GetMSIEndpoint() + if err != nil { + return "", err + } + + var token *adal.ServicePrincipalToken + var access string + if clientID == "" { + access = "system identity" + token, err = adal.NewServicePrincipalTokenFromMSI(msiEndpoint, serverSPN) + } else { + access = "user-assigned identity " + clientID + token, err = adal.NewServicePrincipalTokenFromMSIWithUserAssignedID(msiEndpoint, serverSPN, clientID) + } + + if err != nil { + err = fmt.Errorf("Failed to obtain token for %s for resource %s from service %s: %v", access, serverSPN, stsURL, err) + return "", err + } + + return retrieveToken(ctx, token) + } +} diff --git a/azuread/fedauth_adal_test.go b/azuread/fedauth_adal_test.go new file mode 100644 index 00000000..b9e0d5e7 --- /dev/null +++ b/azuread/fedauth_adal_test.go @@ -0,0 +1,203 @@ +package azuread + +import ( + "context" + "database/sql" + "net/url" + "os" + "strings" + "testing" + + mssql "github.com/denisenkom/go-mssqldb" +) + +type testLogger struct { + t *testing.T +} + +func (l testLogger) Printf(format string, v ...interface{}) { + l.t.Logf(format, v...) +} + +func (l testLogger) Println(v ...interface{}) { + l.t.Log(v...) +} + +func checkAzureSQLEnvironment(fedAuth string, t *testing.T) (*url.URL, string) { + u := &url.URL{ + Scheme: "sqlserver", + Host: os.Getenv("SQL_SERVER"), + } + + if u.Host == "" { + t.Skip("Azure SQL Server name not provided in SQL_SERVER environment variable") + } + + database := os.Getenv("SQL_DATABASE") + if database == "" { + t.Skip("Azure SQL database name not provided in SQL_DATABASE environment variable") + } + + tenantID := os.Getenv("AZURE_TENANT_ID") + if tenantID == "" { + t.Skip("Azure tenant ID not provided in AZURE_TENANT_ID environment variable") + } + + query := u.Query() + + query.Add("database", database) + query.Add("encrypt", "true") + query.Add("fedauth", fedAuth) + + u.RawQuery = query.Encode() + + return u, tenantID +} + +func checkFedAuthUserPassword(t *testing.T) *url.URL { + u, _ := checkAzureSQLEnvironment("ActiveDirectoryPassword", t) + + username := os.Getenv("SQL_AD_ADMIN_USER") + password := os.Getenv("SQL_AD_ADMIN_PASSWORD") + + if username == "" || password == "" { + t.Skip("Username and password login requires SQL_AD_ADMIN_USER and SQL_AD_ADMIN_PASSWORD environment variables") + } + + u.User = url.UserPassword(username, password) + + return u +} + +func checkFedAuthAppPassword(t *testing.T) *url.URL { + u, tenantID := checkAzureSQLEnvironment("ActiveDirectoryApplication", t) + + appClientID := os.Getenv("APP_SP_CLIENT_ID") + appPassword := os.Getenv("APP_SP_CLIENT_SECRET") + + if appClientID == "" || appPassword == "" { + t.Skip("Application (service principal) login requires APP_SP_CLIENT_ID and APP_SP_CLIENT_SECRET environment variables") + } + + u.User = url.UserPassword(appClientID+"@"+tenantID, appPassword) + + return u +} + +func checkFedAuthAppCertPath(t *testing.T) *url.URL { + u := checkFedAuthAppPassword(t) + + appCertPath := os.Getenv("APP_SP_CLIENT_CERT") + if appCertPath == "" { + t.Skip("Application (service principal) certificate login requires APP_SP_CLIENT_CERT with path to certificate") + } + + query := u.Query() + query.Add("clientcertpath", appCertPath) + u.RawQuery = query.Encode() + + return u +} + +func checkFedAuthVMSystemID(t *testing.T) (*url.URL, string) { + u, tenantID := checkAzureSQLEnvironment("ActiveDirectoryMSI", t) + + vmClientID := os.Getenv("VM_CLIENT_ID") + if vmClientID == "" { + t.Skip("System-assigned identity login test requires VM_CLIENT_ID environment variable") + } + + return u, vmClientID + "@" + tenantID +} + +func checkFedAuthVMUserAssignedID(t *testing.T) (*url.URL, string) { + u, tenantID := checkAzureSQLEnvironment("ActiveDirectoryMSI", t) + + uaClientID := os.Getenv("UA_CLIENT_ID") + if uaClientID == "" { + t.Skip("User-assigned identity login test requires UA_CLIENT_ID environment variable") + } + + u.User = url.User(uaClientID) + + return u, uaClientID + "@" + tenantID +} + +func checkLoggedInUser(expected string, u *url.URL, t *testing.T) { + db, err := sql.Open("sqlserver", u.String()) + if err != nil { + t.Fatalf("Failed to open URL %v: %v", u, err) + } + + defer db.Close() + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + sql := "SELECT SUSER_NAME()" + + stmt, err := db.PrepareContext(ctx, sql) + if err != nil { + t.Fatalf("Failed to prepare query %s: %v", sql, err) + } + + defer stmt.Close() + + rows, err := stmt.QueryContext(ctx) + if err != nil { + t.Fatalf("Failed to fetch query result for %s: %v", sql, err) + } + + defer rows.Close() + + var username string + if !rows.Next() { + t.Fatalf("Empty result set for query %s", sql) + } + + err = rows.Scan(&username) + if err != nil { + t.Fatalf("Failed to fetch first row for %s: %v", sql, err) + } + + if !strings.EqualFold(username, expected) { + t.Fatalf("Expected username %s: actual: %s", expected, username) + } + + t.Logf("Logged in username %s matches expected %s", username, expected) +} + +func TestFedAuthWithUserAndPassword(t *testing.T) { + mssql.SetLogger(testLogger{t}) + u := checkFedAuthUserPassword(t) + + checkLoggedInUser(u.User.Username(), u, t) +} + +func TestFedAuthWithApplicationUsingPassword(t *testing.T) { + mssql.SetLogger(testLogger{t}) + u := checkFedAuthAppPassword(t) + + checkLoggedInUser(u.User.Username(), u, t) +} + +func TestFedAuthWithApplicationUsingCertificate(t *testing.T) { + mssql.SetLogger(testLogger{t}) + u := checkFedAuthAppCertPath(t) + + checkLoggedInUser(u.User.Username(), u, t) +} + +func TestFedAuthWithSystemAssignedIdentity(t *testing.T) { + u, vmName := checkFedAuthVMSystemID(t) + mssql.SetLogger(testLogger{t}) + + checkLoggedInUser(vmName, u, t) +} + +func TestFedAuthWithUserAssignedIdentity(t *testing.T) { + mssql.SetLogger(testLogger{t}) + u, uaName := checkFedAuthVMUserAssignedID(t) + + checkLoggedInUser(uaName, u, t) +} diff --git a/conn_str.go b/conn_str.go index 26ac50f3..37387e38 100644 --- a/conn_str.go +++ b/conn_str.go @@ -1,6 +1,7 @@ package mssql import ( + "errors" "fmt" "net" "net/url" @@ -13,31 +14,44 @@ import ( const defaultServerPort = 1433 +const ( + fedAuthActiveDirectoryPassword = "ActiveDirectoryPassword" + fedAuthActiveDirectoryIntegrated = "ActiveDirectoryIntegrated" + fedAuthActiveDirectoryMSI = "ActiveDirectoryMSI" + fedAuthActiveDirectoryApplication = "ActiveDirectoryApplication" +) + type connectParams struct { - logFlags uint64 - port uint64 - host string - instance string - database string - user string - password string - dial_timeout time.Duration - conn_timeout time.Duration - keepAlive time.Duration - encrypt bool - disableEncryption bool - trustServerCertificate bool - certificate string - hostInCertificate string - hostInCertificateProvided bool - serverSPN string - workstation string - appname string - typeFlags uint8 - failOverPartner string - failOverPort uint64 - packetSize uint16 - fedAuthAccessToken string + logFlags uint64 + port uint64 + host string + instance string + database string + user string + password string + dial_timeout time.Duration + conn_timeout time.Duration + keepAlive time.Duration + encrypt bool + disableEncryption bool + trustServerCertificate bool + certificate string + hostInCertificate string + hostInCertificateProvided bool + serverSPN string + workstation string + appname string + typeFlags uint8 + failOverPartner string + failOverPort uint64 + packetSize uint16 + fedAuthLibrary byte + fedAuthADALWorkflow byte + aadTenantID string + aadClientCertPath string + securityTokenProvider SecurityTokenProvider + activeDirectoryTokenProvider ActiveDirectoryTokenProvider + tlsKeyLogFile string } func parseConnectParams(dsn string) (connectParams, error) { @@ -230,6 +244,54 @@ func parseConnectParams(dsn string) (connectParams, error) { } } + p.fedAuthLibrary = fedAuthLibraryReserved + fedAuth, ok := params["fedauth"] + if ok { + switch { + case strings.EqualFold(fedAuth, fedAuthActiveDirectoryPassword): + p.fedAuthLibrary = fedAuthLibraryADAL + p.fedAuthADALWorkflow = fedAuthADALWorkflowPassword + case strings.EqualFold(fedAuth, fedAuthActiveDirectoryIntegrated): + // Active Directory Integrated authentication is not fully supported: + // you can only use this by also implementing an a token provider + // and supplying it via ActiveDirectoryTokenProvider in the Connection. + p.fedAuthLibrary = fedAuthLibraryADAL + p.fedAuthADALWorkflow = fedAuthADALWorkflowIntegrated + case strings.EqualFold(fedAuth, fedAuthActiveDirectoryMSI): + // When using MSI, to request a specific client ID or user-assigned identity, + // provide the ID as the username. + p.fedAuthLibrary = fedAuthLibraryADAL + p.fedAuthADALWorkflow = fedAuthADALWorkflowMSI + case strings.EqualFold(fedAuth, fedAuthActiveDirectoryApplication): + p.fedAuthLibrary = fedAuthLibrarySecurityToken + p.aadClientCertPath = params["clientcertpath"] + + // Split the user name into client id and tenant id at the @ symbol + at := strings.IndexRune(p.user, '@') + if at < 1 || at >= (len(p.user)-1) { + f := "Expecting user id to be clientID@tenantID: found '%s'" + return p, fmt.Errorf(f, p.user) + } + + p.aadTenantID = p.user[at+1:] + p.user = p.user[0:at] + default: + f := "Invalid federated authentication type '%s': expected %s, %s or %s" + return p, fmt.Errorf(f, fedAuth, fedAuthActiveDirectoryPassword, fedAuthActiveDirectoryMSI, + fedAuthActiveDirectoryIntegrated, fedAuthActiveDirectoryApplication) + } + + if p.disableEncryption { + f := "Encryption must not be disabled for federated authentication: encrypt='%s'" + return p, fmt.Errorf(f, encrypt) + } + } + + p.tlsKeyLogFile, ok = params["tls key log file"] + if ok && p.tlsKeyLogFile != "" && p.disableEncryption { + return p, errors.New("Cannot set tlsKeyLogFile when encryption is disabled") + } + return p, nil } diff --git a/doc/how-to-test-azure-ad-authentication.md b/doc/how-to-test-azure-ad-authentication.md new file mode 100644 index 00000000..e7564b61 --- /dev/null +++ b/doc/how-to-test-azure-ad-authentication.md @@ -0,0 +1,176 @@ +# How to test Azure AD authentication + +To test Azure AD authentication requires an Azure SQL server configured with an +[Active Directory administrator](https://docs.microsoft.com/en-us/azure/sql-database/sql-database-aad-authentication-configure). +To test managed identity authentication, an Azure virtual machine configured with +[system-assigned and/or user-assigned identities](https://docs.microsoft.com/en-us/azure/active-directory/managed-identities-azure-resources/qs-configure-portal-windows-vm) +is also required. + +The necessary resources can be set up through any means including the +[Azure Portal](https://portal.azure.com/), the Azure CLI, the Azure PowerShell cmdlets or +[Terraform](https://terraform.io/). To support these instructions, use the Terraform script at +[examples/azuread/testing.tf](../examples/azuread/testing.tf). + +## Create Azure infrastructure + +Download [Terraform](https://terraform.io/) to a location on your PATH. + +Log in to Azure using the Azure CLI. + +```console +you@workstation:~$ az login +you@workstation:~$ az account show +``` + +If your Azure account has access to multiple subscriptions, use +`az account set --subscription ` to choose the correct one. You will need to have at +least Contributor access to the portal and permissions in Azure Active Directory to create users +and grants. + +Check out this source repository (if you haven't already), change to the `examples/azuread` +directory and run Terraform: + +```console +you@workstation:~$ git clone -b azure-auth https://github.com/wrosenuance/go-mssqldb.git +you@workstation:~$ cd go-mssqldb/examples/azuread +you@workstation:azuread$ terraform init +you@workstation:azuread$ terraform apply +``` + +This will create an Azure resource group, a SQL server with a database, a virtual machine with a +system-assigned identity and user-assigned identity. Resources are named based on a random +prefix: to specify the prefix, use `terraform apply -var prefix=`. + +Upon successful completion, Terraform will display some key details of the infrastructure that has + been created. This includes the SSH key to access the VM, the administrator account and password + for the Azure SQL server, and all the relevant resource names. + +Save the settings to a JSON file: + +```console +you@workstation:azuread$ terraform output -json > settings.json +``` + +Save the SSH private key to a file: + +```console +you@workstation:azuread$ terraform output vm_user_ssh_private_key > ssh-identity +``` + +Copy the `settings.json` to the new VM: + +```console +you@workstation:azuread$ scp -i ssh-identity settings.json "$(terraform output vm_admin_name)@$(terraform output vm_ip_address):" +``` + +## Set up Azure Virtual Machine for testing + +SSH to the new VM to continue setup: + +```console +you@workstation:azuread$ ssh -i ssh-identity "$(terraform output vm_admin_name)@$(terraform output vm_ip_address)" +``` + +Once on the VM, update the system and install some basic packages: + +```console +azureuser@azure-vm:~$ sudo apt update -y +azureuser@azure-vm:~$ sudo apt upgrade -y +azureuser@azure-vm:~$ sudo apt install -y git openssl jq build-essential +azureuser@azure-vm:~$ sudo snap install go --classic +``` + +Install the Azure CLI using the script as shown below, or follow the +[manual install instructions](https://docs.microsoft.com/en-us/cli/azure/install-azure-cli-apt): + +```console +azureuser@azure-vm:~$ curl -sL https://aka.ms/InstallAzureCLIDeb | sudo bash +``` + +## Generate service principal certificate file + +Log in to Azure on the VM and set the subscription: + +```console +azureuser@azure-vm:~$ az login +azureuser@azure-vm:~$ az account set --subscription "$(jq -r '.subscription_id.value' settings.json)" +``` + +Use OpenSSL to create a new certificate and key in PEM format, using the : + +```console +azureuser@azure-vm:~$ openssl rand -writerand ~/.rnd +azureuser@azure-vm:~$ openssl req -x509 -nodes -newkey rsa:4096 -keyout client.key -out client.crt \ + -subj "/C=US/ST=MA/L=Boston/O=Global Security/OU=IT Department/CN=AD-SP" +azureuser@azure-vm:~$ openssl pkcs12 -export -out client.p12 -inkey client.key -in client.crt \ + -passout "pass:$(jq -r '.app_sp_client_secret.value' settings.json)" +azureuser@azure-vm:~$ export APP_SP_CLIENT_CERT="$PWD/client.p12" +``` + +Use the Azure CLI to add the client certificate to the application service principal: + +```console +azureuser@azure-vm:~$ az ad sp credential reset --append --cert @client.crt \ + --name "$(jq -r '.app_sp_client_id.value' settings.json)" +``` + +## Build source code and authorize users in database + +Clone this repository, build and run the `examples/azuread` helper that verifies the database +exists and sets up access for the system-assigned and user-assigned identities. + +```console +azureuser@azure-vm:~$ git clone -b azure-auth https://github.com/wrosenuance/go-mssqldb.git +azureuser@azure-vm:~$ cd go-mssqldb +azureuser@azure-vm:go-mssqldb$ go generate ./... +azureuser@azure-vm:go-mssqldb$ (cd ./examples/azuread; go build -o ../../azuread-example .) +azureuser@azure-vm:go-mssqldb$ eval "$(jq -r -f examples/azuread/environment-settings.jq ../settings.json)" +azureuser@azure-vm:go-mssqldb$ ./azuread-example -fedauth ActiveDirectoryPassword +``` + +For some basic connectivity tests, use the `examples/simple` helper. Run these commands on the +Azure VM so that identity authentication is possible. + +```console +azureuser@azure-vm:go-mssqldb$ go build -o simple ./examples/simple +azureuser@azure-vm:go-mssqldb$ jq -r -f examples/azuread/ad-user-password-dsn.jq ../settings.json | + xargs ./simple -debug -dsn +azureuser@azure-vm:go-mssqldb$ jq -r -f examples/azuread/ad-service-principal-dsn.jq ../settings.json | + xargs ./simple -debug -dsn +azureuser@azure-vm:go-mssqldb$ jq -r -f examples/azuread/ad-system-assigned-id-dsn.jq ../settings.json | + xargs ./simple -debug -dsn +azureuser@azure-vm:go-mssqldb$ jq -r -f examples/azuread/ad-user-assigned-id-dsn.jq ../settings.json | + xargs ./simple -debug -dsn +``` + +## Running the integration tests + +Now that your environment is configured, you can run `go test`: + +```console +azureuser@azure-vm:go-mssqldb$ export SQLSERVER_DSN="$(jq -r -f examples/azuread/ad-system-assigned-id-dsn.jq ../settings.json)" +azureuser@azure-vm:go-mssqldb$ go test -coverprofile=coverage.out ./... +``` + +## Tear down environment + +After you complete your testing, use Terraform to destroy the infrastructure you created. + +```console +you@workstation:azuread$ terraform destroy +``` + +## Troubleshooting + +After Terraform runs you should be able to see resources that were created in the +[Azure Portal](https://portal.azure.com/). + +If the Azure SQL server is successfully created you can connect to it using the AD admin user +and password in SSMS. SSMS will prompt you to create firewall rules if they are missing. You +can read the AD admin user and password from the `settings.json`, or run: + +```console +you@workstation:azuread$ terraform output sql_ad_admin_user +you@workstation:azuread$ terraform output sql_ad_admin_password +``` + diff --git a/examples/azuread/ad-service-principal-dsn.jq b/examples/azuread/ad-service-principal-dsn.jq new file mode 100644 index 00000000..b2426037 --- /dev/null +++ b/examples/azuread/ad-service-principal-dsn.jq @@ -0,0 +1 @@ +@uri "sqlserver://\(.app_sp_client_id.value)%40\(.tenant_id.value):\(.app_sp_client_secret.value)@\(.sql_server_fqdn.value)?database=\(.sql_database_name.value)&encrypt=true&fedauth=ActiveDirectoryApplication" \ No newline at end of file diff --git a/examples/azuread/ad-system-assigned-id-dsn.jq b/examples/azuread/ad-system-assigned-id-dsn.jq new file mode 100644 index 00000000..288f2b7c --- /dev/null +++ b/examples/azuread/ad-system-assigned-id-dsn.jq @@ -0,0 +1 @@ +@uri "sqlserver://\(.sql_server_fqdn.value)?database=\(.sql_database_name.value)&encrypt=true&fedauth=ActiveDirectoryMSI" \ No newline at end of file diff --git a/examples/azuread/ad-user-assigned-id-dsn.jq b/examples/azuread/ad-user-assigned-id-dsn.jq new file mode 100644 index 00000000..df31d09d --- /dev/null +++ b/examples/azuread/ad-user-assigned-id-dsn.jq @@ -0,0 +1 @@ +@uri "sqlserver://\(.user_assigned_identity_client_id.value)@\(.sql_server_fqdn.value)?database=\(.sql_database_name.value)&encrypt=true&fedauth=ActiveDirectoryMSI" \ No newline at end of file diff --git a/examples/azuread/ad-user-password-dsn.jq b/examples/azuread/ad-user-password-dsn.jq new file mode 100644 index 00000000..beebc5e1 --- /dev/null +++ b/examples/azuread/ad-user-password-dsn.jq @@ -0,0 +1 @@ +@uri "sqlserver://\(.sql_ad_admin_user.value):\(.sql_ad_admin_password.value)@\(.sql_server_fqdn.value)?database=\(.sql_database_name.value)&encrypt=true&fedauth=ActiveDirectoryPassword" \ No newline at end of file diff --git a/examples/azuread/azuread.go b/examples/azuread/azuread.go new file mode 100644 index 00000000..e8e91cf9 --- /dev/null +++ b/examples/azuread/azuread.go @@ -0,0 +1,144 @@ +package main + +import ( + "database/sql" + "flag" + "fmt" + "log" + "net/url" + "os" + "strings" + "time" + + _ "github.com/denisenkom/go-mssqldb/azuread" +) + +var ( + debug = flag.Bool("debug", false, "enable debugging") + server = flag.String("server", os.Getenv("SQL_SERVER"), "the database server name") + port = flag.Int("port", 1433, "the database port") + database = flag.String("database", os.Getenv("SQL_DATABASE"), "the database name") + user = flag.String("user", os.Getenv("SQL_AD_ADMIN_USER"), "the AD administrator user name") + password = flag.String("password", os.Getenv("SQL_AD_ADMIN_PASSWORD"), "the AD administrator password") + fedauth = flag.String("fedauth", "ActiveDirectoryPassword", "the federated authentication scheme to use") + appName = flag.String("app-name", os.Getenv("APP_NAME"), "the application name to authorize") + vmName = flag.String("vm-name", os.Getenv("VM_NAME"), "the system identity name to authorize for this VM") + uaName = flag.String("ua-name", os.Getenv("UA_NAME"), "the user assigned identity name to authorize for this VM") +) + +func createConnStr(database string) string { + connString := fmt.Sprintf("sqlserver://%s:%s@%s:%d?encrypt=true", + url.QueryEscape(*user), url.QueryEscape(*password), + url.QueryEscape(*server), *port) + + if database != "" && database != "master" { + connString = connString + "&database=" + url.QueryEscape(database) + } + + if *fedauth != "" { + connString = connString + "&fedauth=" + url.QueryEscape(*fedauth) + } + + if *debug { + connString = connString + "&log=127" + } + + return connString +} + +func createDatabaseIfNotExists() error { + // Check database exists by connecting to master on the Azure SQL server + connString := createConnStr("master") + + log.Printf("Open: %s\n", connString) + + conn, err := sql.Open("sqlserver", connString) + if err != nil { + return err + } + + defer conn.Close() + + if err = conn.Ping(); err != nil { + return err + } + + quoted := strings.Replace(*database, "]", "]]", -1) + sql := "IF NOT EXISTS (SELECT 1 FROM sys.databases WHERE name = @p1)\n CREATE DATABASE [" + quoted + "] ( SERVICE_OBJECTIVE = 'S0' )" + log.Printf("Exec: @p1 = '%s'\n%s\n", *database, sql) + _, err = conn.Exec(sql, *database) + + return err +} + +func addExternalUserIfNotExists(user string) error { + connString := createConnStr(*database) + + log.Printf("Open: %s\n", connString) + + var conn *sql.DB + var err error + + for retry := 0; retry < 8; retry++ { + conn, err = sql.Open("sqlserver", connString) + if err == nil { + if err = conn.Ping(); err == nil { + break + } + } + log.Printf("Connection failed: %v", err) + log.Println("Retry in 15 seconds") + time.Sleep(15 * time.Second) + } + if err != nil { + log.Printf("Connection failed: %v", err) + log.Println("No further retries will be attempted") + return err + } + + defer conn.Close() + + quoted := strings.Replace(user, "]", "]]", -1) + sql := "IF NOT EXISTS (SELECT 1 FROM sys.database_principals WHERE name = @p1)\n CREATE USER [" + quoted + "] FROM EXTERNAL PROVIDER" + log.Printf("Exec: @p1 = '%s'\n%s\n", user, sql) + _, err = conn.Exec(sql, user) + if err != nil { + return err + } + + sql = "IF IS_ROLEMEMBER('db_owner', @p1) = 0\n ALTER ROLE [db_owner] ADD MEMBER [" + quoted + "]" + log.Printf("Exec: @p1 = '%s'\n%s\n", user, sql) + _, err = conn.Exec(sql, user) + + return err +} + +func main() { + flag.Parse() + + err := createDatabaseIfNotExists() + if err != nil { + log.Fatalf("Unable to create database [%s]: %v", *database, err) + } + + if *vmName != "" { + err = addExternalUserIfNotExists(*vmName) + if err != nil { + log.Fatalf("Unable to create user for system-assigned identity [%s]: %v", *vmName, err) + } + } + + if *appName != "" { + err = addExternalUserIfNotExists(*appName) + if err != nil { + log.Fatalf("Unable to create user for application identity [%s]: %v", *appName, err) + } + } + + if *uaName != "" { + err = addExternalUserIfNotExists(*uaName) + if err != nil { + log.Fatalf("Unable to create user for user-assigned identity [%s]: %v", *uaName, err) + } + } +} diff --git a/examples/azuread/environment-settings.jq b/examples/azuread/environment-settings.jq new file mode 100644 index 00000000..a8c9192a --- /dev/null +++ b/examples/azuread/environment-settings.jq @@ -0,0 +1,20 @@ +# Convert Terraform settings to shell environment exports. +[ + "set -a", + "SQL_SERVER=" + (.sql_server_fqdn.value | @sh), + "SQL_ADMIN_USER=" + (.sql_admin_user.value | @sh), + "SQL_ADMIN_PASSWORD=" + (.sql_admin_password.value | @sh), + "SQL_AD_ADMIN_USER=" + (.sql_ad_admin_user.value | @sh), + "SQL_AD_ADMIN_PASSWORD=" + (.sql_ad_admin_password.value | @sh), + "APP_SP_CLIENT_ID=" + (.app_sp_client_id.value | @sh), + "APP_SP_CLIENT_SECRET=" + (.app_sp_client_secret.value | @sh), + "SQL_DATABASE=" + (.sql_database_name.value | @sh), + "APP_NAME=" + (.app_name.value | @sh), + "VM_NAME=" + (.vm_name.value | @sh), + "VM_CLIENT_ID=" + (.vm_client_id.value | @sh), + "UA_NAME=" + (.user_assigned_identity_name.value | @sh), + "UA_CLIENT_ID=" + (.user_assigned_identity_client_id.value | @sh), + "AZURE_SUBSCRIPTION_ID=" + (.subscription_id.value | @sh), + "AZURE_TENANT_ID=" + (.tenant_id.value | @sh), + "set +a" +] | map([.]) | .[] | @tsv diff --git a/examples/azuread/go.mod b/examples/azuread/go.mod new file mode 100644 index 00000000..217d12b0 --- /dev/null +++ b/examples/azuread/go.mod @@ -0,0 +1,10 @@ +module github.com/denisenkom/go-mssqldb/examples/azuread + +go 1.13 + +require ( + github.com/Azure/go-autorest/autorest/adal v0.8.1 + github.com/denisenkom/go-mssqldb v0.0.0-20191128021309-1d7a30a10f73 +) + +replace github.com/denisenkom/go-mssqldb => ../.. \ No newline at end of file diff --git a/examples/azuread/testing.tf b/examples/azuread/testing.tf new file mode 100644 index 00000000..cef943d8 --- /dev/null +++ b/examples/azuread/testing.tf @@ -0,0 +1,490 @@ +# +# Terraform setup for Azure SQL with Azure Active Directory authentication +# + +# +# Set up Terraform provider versions +# +provider "azuread" { + version = "~> 0.7" +} + +provider "azurerm" { + version = "~> 1.36" +} + +provider "http" { + version = "~> 1.1" +} + +provider "random" { + version = "~> 2.2" +} + +provider "tls" { + version = "~> 2.1" +} + +# +# Variables +# +# These variables allow limited overrides to control the resource creation. +# To specify, run terraform apply -var name1=value1 [-var name2=value2]... +# E.g. terraform apply -var prefix=my-stuff +# will use "my-stuff" in place of the randomly generated ID that is used by default. +# +variable "prefix" { + description = "Prefix for Azure resource names" + type = string + default = "" +} + +variable "location" { + description = "Azure location for resources" + type = string + default = "East US" +} + +variable "vm_admin_name" { + description = "Name of administrative user on virtual machine" + type = string + default = "azureuser" +} + +variable "ssh_key" { + description = "Path to RSA SSH private key (unencrypted)" + type = string + default = "~/.ssh/id_rsa" +} + +variable "workstation_ip" { + description = "IP address of this workstation to add to SQL server firewall rules" + type = string + default = "" +} + +# +# If the prefix is not specified via the variable, a sixteen character alphanumeric suffix is +# generated and then the prefix is set to "go-mssql-test-" + +# +resource "random_string" "random_prefix" { + length = 16 + lower = true + number = true + upper = false + special = false +} + +# +# Set up a local variable to capture the prefix to use - either the user-specified from the +# variable, or else the generated name using the random string above. +# +# Some resource names (e.g. SQL server) are more restricted than others - e.g. hyphens are +# not permitted - so we create a restricted name prefix as well as a regular name prefix. +# +locals { + regular_name_prefix = var.prefix != "" ? var.prefix : "go-mssql-test-${random_string.random_prefix.result}" + restricted_name_prefix = var.prefix != "" ? lower(replace(var.prefix, "/[^A-Za-z0-9]/", "")) : "gomssqltest${random_string.random_prefix.result}" +} + +# +# SSH Key - generate if not available at the file named in the variable. +# Terraform will complain if var.ssh_key is empty as this is interpreted as referring to the +# current working directory, and that is not a file. Instead, if you want to avoid using an +# existing SSH key, make it a literal "no" or some other string that is not an existing file or +# directory. +# +data "tls_public_key" "file_ssh_key" { + count = fileexists(var.ssh_key) ? 1 : 0 + private_key_pem = fileexists(var.ssh_key) ? file(var.ssh_key) : "" +} + +resource "tls_private_key" "rand_ssh_key" { + algorithm = "ECDSA" +} + +locals { + private_key_pem = fileexists(var.ssh_key) ? data.tls_public_key.file_ssh_key.0.private_key_pem : tls_private_key.rand_ssh_key.private_key_pem + public_key_pem = fileexists(var.ssh_key) ? data.tls_public_key.file_ssh_key.0.public_key_pem : tls_private_key.rand_ssh_key.public_key_pem + public_key_openssh = fileexists(var.ssh_key) ? data.tls_public_key.file_ssh_key.0.public_key_openssh : tls_private_key.rand_ssh_key.public_key_openssh +} + +# +# Retrieve tenant, subscription and default domain information based on the current Azure login. +# +data "azurerm_client_config" "current" { +} + +data "azurerm_subscription" "current" { +} + +data "azuread_domains" "current" { + only_default = "true" +} + +# +# Use ipify.org to determine workstation IP if not provided. +# If this guesses incorrectly, specify your workstation IP with -var worstation_ip= +# when you run terraform apply. +# +data "http" "workstation_ip" { + url = "https://api.ipify.org/" +} + +locals { + workstation_ip = var.workstation_ip != "" ? var.workstation_ip : chomp(data.http.workstation_ip.body) +} + +# +# Set up the Azure resource group for all the test resources. +# +resource "azurerm_resource_group" "rg" { + name = "${local.regular_name_prefix}-rg" + location = var.location +} + +# +# Set up an AD User to use as AD Administrator for the Azure SQL server. +# +# Using a regular user account makes it simpler to log in as the user with SSMS or the Go +# driver when setting up the other permissions for the identities that will be tested. +# It appears to although you can make the AD Administrator a service principal, doing so +# leads to issues during logins that do not occur when the AD Administrator is a normal +# AD User account. +# +resource "random_password" "sql_ad_admin_sp_password" { + length = 32 + special = true +} + +resource "azuread_user" "sql_ad_admin" { + user_principal_name = "SQLAdmin.${local.restricted_name_prefix}@${data.azuread_domains.current.domains[0].domain_name}" + display_name = "SQL Admin for ${local.restricted_name_prefix}" + mail_nickname = "SQLAdmin.${local.restricted_name_prefix}" + password = random_password.sql_ad_admin_sp_password.result +} + +# +# Set up the Azure SQL Server +# +# A normal (non-AD) administrator username and password are also provisioned. However, it is +# not possible to create AD users without logging in via an AD-authenticated account, so this +# non-AD administrator is not able to create new AD user accounts. +# +resource "random_password" "sql_admin_password" { + length = 16 + special = true +} + +resource "azurerm_sql_server" "sql_server" { + name = local.restricted_name_prefix + resource_group_name = azurerm_resource_group.rg.name + location = azurerm_resource_group.rg.location + + version = "12.0" + administrator_login = "sql-admin" + administrator_login_password = random_password.sql_admin_password.result +} + +resource "azurerm_sql_active_directory_administrator" "sql_server" { + server_name = azurerm_sql_server.sql_server.name + resource_group_name = azurerm_sql_server.sql_server.resource_group_name + login = "sql-ad-admin" + tenant_id = data.azurerm_client_config.current.tenant_id + object_id = azuread_user.sql_ad_admin.id +} + +resource "azurerm_sql_firewall_rule" "sql_server_allow_azure" { + name = "AllowAzureAccess" + server_name = azurerm_sql_server.sql_server.name + resource_group_name = azurerm_sql_server.sql_server.resource_group_name + start_ip_address = "0.0.0.0" + end_ip_address = "0.0.0.0" +} + +resource "azurerm_sql_firewall_rule" "sql_server_allow_workstation" { + name = "AllowWorkstationAccess" + server_name = azurerm_sql_server.sql_server.name + resource_group_name = azurerm_sql_server.sql_server.resource_group_name + start_ip_address = local.workstation_ip + end_ip_address = local.workstation_ip +} + +# +# Set up the test database on the Azure SQL server +# +resource "azurerm_sql_database" "sql_db" { + name = "go-mssqldb" + + server_name = azurerm_sql_server.sql_server.name + resource_group_name = azurerm_sql_server.sql_server.resource_group_name + location = azurerm_sql_server.sql_server.location + + requested_service_objective_name = "S0" +} + +# +# Create a service principal that will be granted access to the database, +# representing an application login to the database. +# +resource "azuread_application" "app" { + name = "${local.regular_name_prefix}-app" +} + +resource "azuread_service_principal" "app_sp" { + application_id = azuread_application.app.application_id + app_role_assignment_required = false +} + +resource "random_password" "app_sp_password" { + length = 32 + special = true +} + +resource "azuread_service_principal_password" "app_sp" { + service_principal_id = azuread_service_principal.app_sp.id + value = random_password.app_sp_password.result + end_date_relative = "8760h" +} + + +# +# Create a user-assigned identity that we will add to the VM in addition to the +# system-assigned identity. +# +resource "azurerm_user_assigned_identity" "vm_user_id" { + resource_group_name = azurerm_resource_group.rg.name + location = azurerm_resource_group.rg.location + + name = "${local.restricted_name_prefix}-user-id" +} + +# +# Create an Azure VM for testing managed identity authentication. +# +# To support the Azure VM, we need a virtual network, a subnet, the public IP, the network +# security group, and the network interface. The network security group allows incoming SSH +# from the anywhere on the internet. +# +resource "azurerm_virtual_network" "vm_vnet" { + name = "${local.regular_name_prefix}-vnet" + resource_group_name = azurerm_resource_group.rg.name + location = azurerm_resource_group.rg.location + address_space = ["10.0.0.0/16"] +} + +resource "azurerm_subnet" "vm_subnet" { + name = "${local.regular_name_prefix}-vm-sn" + resource_group_name = azurerm_resource_group.rg.name + virtual_network_name = azurerm_virtual_network.vm_vnet.name + address_prefix = "10.0.2.0/24" +} + +resource "azurerm_public_ip" "vm_ip" { + name = "${local.regular_name_prefix}-vm-ip" + resource_group_name = azurerm_resource_group.rg.name + location = azurerm_resource_group.rg.location + allocation_method = "Dynamic" + idle_timeout_in_minutes = 30 +} + +resource "azurerm_network_security_group" "vm_nsg" { + name = "${local.regular_name_prefix}-vm-nsg" + resource_group_name = azurerm_resource_group.rg.name + location = azurerm_resource_group.rg.location + + security_rule { + name = "SSH" + priority = 1001 + direction = "Inbound" + access = "Allow" + protocol = "Tcp" + source_port_range = "*" + destination_port_range = "22" + source_address_prefix = "*" + destination_address_prefix = "*" + } +} + +resource "azurerm_network_interface" "vm_nic" { + name = "${local.regular_name_prefix}-vm-nic" + resource_group_name = azurerm_resource_group.rg.name + location = azurerm_resource_group.rg.location + network_security_group_id = azurerm_network_security_group.vm_nsg.id + + ip_configuration { + name = "${local.regular_name_prefix}-vm-nic-config" + subnet_id = azurerm_subnet.vm_subnet.id + private_ip_address_allocation = "Dynamic" + public_ip_address_id = azurerm_public_ip.vm_ip.id + } +} + +# +# Given the networking setup, now create the Azure VM +# +resource "azurerm_virtual_machine" "vm" { + name = "${local.regular_name_prefix}-vm" + resource_group_name = azurerm_resource_group.rg.name + location = azurerm_resource_group.rg.location + network_interface_ids = [azurerm_network_interface.vm_nic.id] + vm_size = "Standard_B1s" + + storage_os_disk { + name = "${local.regular_name_prefix}-vm-os" + caching = "ReadWrite" + create_option = "FromImage" + managed_disk_type = "Standard_LRS" + } + + storage_image_reference { + publisher = "Canonical" + offer = "UbuntuServer" + sku = "18.04-LTS" + version = "latest" + } + + os_profile { + computer_name = "${local.regular_name_prefix}-vm" + admin_username = var.vm_admin_name + } + + os_profile_linux_config { + disable_password_authentication = true + ssh_keys { + path = "/home/${var.vm_admin_name}/.ssh/authorized_keys" + key_data = local.public_key_openssh + } + } + + # Configure the VM with both SystemAssigned and a UserAssigned identity + identity { + type = "SystemAssigned, UserAssigned" + identity_ids = [azurerm_user_assigned_identity.vm_user_id.id] + } +} + +# Retrieve the application ID corresponding to the service principal ID assigned to the VM. +data "azuread_service_principal" "vm_sp" { + object_id = azurerm_virtual_machine.vm.identity.0.principal_id +} + +# Wait for public IP to be assigned after VM is created so we can report it in the outputs. +data "azurerm_public_ip" "vm_ip" { + name = azurerm_public_ip.vm_ip.name + resource_group_name = azurerm_virtual_machine.vm.resource_group_name +} + +# +# After provisioning or refreshing, Terraform will populate these outputs. +# These capture the necessary pieces of information to access the new infrastructure. +# +output "tenant_id" { + description = "Azure tenant ID" + value = data.azurerm_client_config.current.tenant_id +} + +output "subscription_id" { + description = "Azure subscription ID" + value = data.azurerm_client_config.current.subscription_id +} + +output "sql_server_name" { + description = "Azure SQL server name" + value = azurerm_sql_server.sql_server.name +} + +output "sql_server_fqdn" { + description = "Azure SQL server domain name" + value = azurerm_sql_server.sql_server.fully_qualified_domain_name +} + +output "sql_ad_admin_user" { + description = "Azure SQL administrator name (AD authentication)" + value = azuread_user.sql_ad_admin.user_principal_name +} + +output "sql_ad_admin_password" { + description = "Azure SQL administrator password (AD authentication)" + value = random_password.sql_ad_admin_sp_password.result + sensitive = true +} + +output "sql_admin_user" { + description = "Azure SQL administrator name (SQL server authentication)" + value = azurerm_sql_server.sql_server.administrator_login +} + +output "sql_admin_password" { + description = "Azure SQL administrator password (SQL server authentication)" + value = random_password.sql_admin_password.result + sensitive = true +} + +output "sql_database_name" { + description = "Azure SQL database name" + value = azurerm_sql_database.sql_db.name +} + +output "vm_name" { + description = "Azure virtual machine name" + value = azurerm_virtual_machine.vm.name +} + +output "vm_client_id" { + description = "Azure VM system-assigned identity client ID" + value = data.azuread_service_principal.vm_sp.application_id +} + +output "vm_principal_id" { + description = "Azure VM system-assigned identity principal ID" + value = azurerm_virtual_machine.vm.identity.0.principal_id +} + +output "vm_ip_address" { + description = "Azure virtual machine public IP" + value = data.azurerm_public_ip.vm_ip.ip_address +} + +output "vm_admin_name" { + description = "Azure virtual machine admin user name" + value = var.vm_admin_name +} + +output "vm_user_ssh_private_key" { + description = "Azure virtual machine admin user private SSH key" + value = local.private_key_pem + sensitive = true +} + +output "vm_user_ssh_openssh_key" { + description = "Azure virtual machine admin user SSH public key" + value = local.public_key_openssh + sensitive = true +} + +output "app_sp_client_id" { + description = "Service principal client ID for application user" + value = azuread_application.app.application_id +} + +output "app_name" { + description = "Service principal name for application user" + value = azuread_application.app.name +} + +output "app_sp_client_secret" { + description = "Service principal client secret for application user" + value = random_password.app_sp_password.result + sensitive = true +} + +output "user_assigned_identity_name" { + description = "User-assigned identity for the Azure virtual machine" + value = azurerm_user_assigned_identity.vm_user_id.name +} + +output "user_assigned_identity_client_id" { + description = "User-assigned identity client ID" + value = azurerm_user_assigned_identity.vm_user_id.client_id +} diff --git a/examples/simple/simple.go b/examples/simple/simple.go index 67f88aa4..a2590502 100644 --- a/examples/simple/simple.go +++ b/examples/simple/simple.go @@ -5,12 +5,16 @@ import ( "flag" "fmt" "log" + "net/url" + "os" - _ "github.com/denisenkom/go-mssqldb" + _ "github.com/denisenkom/go-mssqldb/azuread" ) var ( + database = flag.String("database", "", "the database name") debug = flag.Bool("debug", false, "enable debugging") + dsn = flag.String("dsn", os.Getenv("SQLSERVER_DSN"), "complete SQL DSN") password = flag.String("password", "", "the database password") port *int = flag.Int("port", 1433, "the database port") server = flag.String("server", "", "the database server") @@ -20,24 +24,35 @@ var ( func main() { flag.Parse() - if *debug { - fmt.Printf(" password:%s\n", *password) - fmt.Printf(" port:%d\n", *port) - fmt.Printf(" server:%s\n", *server) - fmt.Printf(" user:%s\n", *user) + var connString string + + if *dsn == "" { + if *debug { + fmt.Printf(" server: %s\n", *server) + fmt.Printf(" port: %d\n", *port) + fmt.Printf(" user: %s\n", *user) + fmt.Printf(" password: %s\n", *password) + fmt.Printf(" database: %s\n", *database) + } + + connString = fmt.Sprintf("sqlserver://%s:%s@%s:%d?database=%s&encrypt=true", + url.QueryEscape(*user), url.QueryEscape(*password), + url.QueryEscape(*server), *port, url.QueryEscape(*database)) + } else { + connString = *dsn } - connString := fmt.Sprintf("server=%s;user id=%s;password=%s;port=%d", *server, *user, *password, *port) if *debug { - fmt.Printf(" connString:%s\n", connString) + fmt.Printf(" dsn: %s\n", connString) } + conn, err := sql.Open("mssql", connString) if err != nil { log.Fatal("Open connection failed:", err.Error()) } defer conn.Close() - stmt, err := conn.Prepare("select 1, 'abc'") + stmt, err := conn.Prepare("select 1, 'abc', suser_name()") if err != nil { log.Fatal("Prepare failed:", err.Error()) } @@ -46,12 +61,14 @@ func main() { row := stmt.QueryRow() var somenumber int64 var somechars string - err = row.Scan(&somenumber, &somechars) + var someuser string + err = row.Scan(&somenumber, &somechars, &someuser) if err != nil { log.Fatal("Scan failed:", err.Error()) } - fmt.Printf("somenumber:%d\n", somenumber) - fmt.Printf("somechars:%s\n", somechars) + fmt.Printf("number: %d\n", somenumber) + fmt.Printf("chars: %s\n", somechars) + fmt.Printf("user: %s\n", someuser) fmt.Printf("bye\n") } diff --git a/examples/tvp/tvp.go b/examples/tvp/tvp.go index a07bb652..eae614ef 100644 --- a/examples/tvp/tvp.go +++ b/examples/tvp/tvp.go @@ -1,3 +1,5 @@ +// +build go1.9 + package main import ( diff --git a/fedauth.go b/fedauth.go new file mode 100644 index 00000000..45bd85c1 --- /dev/null +++ b/fedauth.go @@ -0,0 +1,103 @@ +package mssql + +import ( + "crypto/rsa" + "crypto/x509" + "errors" + "fmt" + "io/ioutil" + + "golang.org/x/crypto/pkcs12" +) + +// FederatedAuthenticationConfigurer implementations use the connection string +// parameters to create a token provider that can be used to obtain tokens +// during the login sequence. +type FederatedAuthenticationConfigurer interface { + SecurityTokenProviderFromCertificate(clientID, tenantID string, cert *x509.Certificate, key *rsa.PrivateKey) SecurityTokenProvider + SecurityTokenProviderFromSecret(clientID, tenantID, clientSecret string) SecurityTokenProvider + ActiveDirectoryTokenProviderFromPassword(user, password string) ActiveDirectoryTokenProvider + ActiveDirectoryTokenProviderFromIdentity(clientID string) ActiveDirectoryTokenProvider +} + +// SetFederatedAuthenticationConfigurer injects an implementation to use. +func SetFederatedAuthenticationConfigurer(library FederatedAuthenticationConfigurer) { + fedAuthConfigurer = library +} + +var fedAuthConfigurer FederatedAuthenticationConfigurer + +func (c *Connector) checkFedAuthProviders(params *connectParams) error { + // If there is a user-specified SecurityTokenProvider, use that in preference + // to DSN or ActiveDirectoryTokenProvider. + if c != nil && c.SecurityTokenProvider != nil { + params.fedAuthLibrary = fedAuthLibrarySecurityToken + params.securityTokenProvider = c.SecurityTokenProvider + + return nil + } + + // Likewise if there is an existing user-specified ActiveDirectoryTokenProvider. + if c != nil && c.ActiveDirectoryTokenProvider != nil { + params.fedAuthLibrary = fedAuthLibraryADAL + params.activeDirectoryTokenProvider = c.ActiveDirectoryTokenProvider + + return nil + } + + // Ignore DSNs that don't request one of the supported federated authentication + // libraries. + if params.fedAuthLibrary != fedAuthLibrarySecurityToken && params.fedAuthLibrary != fedAuthLibraryADAL { + return nil + } + + if fedAuthConfigurer == nil { + return errors.New("No federated authentication library available: inject using SetFederatedAuthenticationConfigurer") + } + + switch { + case params.fedAuthLibrary == fedAuthLibrarySecurityToken && params.aadClientCertPath != "": + certificate, rsaPrivateKey, err := getFedAuthClientCertificate(params.aadClientCertPath, params.password) + if err != nil { + return err + } + + params.securityTokenProvider = fedAuthConfigurer.SecurityTokenProviderFromCertificate(params.user, params.aadTenantID, certificate, rsaPrivateKey) + + case params.fedAuthLibrary == fedAuthLibrarySecurityToken: + params.securityTokenProvider = fedAuthConfigurer.SecurityTokenProviderFromSecret(params.user, params.aadTenantID, params.password) + + case params.fedAuthLibrary == fedAuthLibraryADAL && params.fedAuthADALWorkflow == fedAuthADALWorkflowPassword: + params.activeDirectoryTokenProvider = fedAuthConfigurer.ActiveDirectoryTokenProviderFromPassword(params.user, params.password) + + case params.fedAuthLibrary == fedAuthLibraryADAL && params.fedAuthADALWorkflow == fedAuthADALWorkflowMSI: + params.activeDirectoryTokenProvider = fedAuthConfigurer.ActiveDirectoryTokenProviderFromIdentity(params.user) + + case params.fedAuthLibrary == fedAuthLibraryADAL: + return errors.New("Unsupported ADAL workflow") + + default: + return errors.New("Unsupported federated authentication library") + } + + return nil +} + +func getFedAuthClientCertificate(clientCertPath, clientCertPassword string) (*x509.Certificate, *rsa.PrivateKey, error) { + pkcs, err := ioutil.ReadFile(clientCertPath) + if err != nil { + return nil, nil, fmt.Errorf("Failed to read the AD client certificate from path %s: %v", clientCertPath, err) + } + + privateKey, certificate, err := pkcs12.Decode(pkcs, clientCertPassword) + if err != nil { + return nil, nil, fmt.Errorf("Failed to read the AD client certificate from path %s: %v", clientCertPath, err) + } + + rsaPrivateKey, isRsaKey := privateKey.(*rsa.PrivateKey) + if !isRsaKey { + return nil, nil, fmt.Errorf("AD client certificate at path %s must contain an RSA private key", clientCertPath) + } + + return certificate, rsaPrivateKey, nil +} diff --git a/go.mod b/go.mod index ebc02ab8..2526b04a 100644 --- a/go.mod +++ b/go.mod @@ -3,6 +3,8 @@ module github.com/denisenkom/go-mssqldb go 1.11 require ( + github.com/Azure/go-autorest/autorest/adal v0.8.0 github.com/golang-sql/civil v0.0.0-20190719163853-cb61b32ac6fe golang.org/x/crypto v0.0.0-20190325154230-a5d413f7728c + golang.org/x/tools v0.0.0-20191206204035-259af5ff87bd // indirect ) diff --git a/go.sum b/go.sum index 1887801b..d461dea2 100644 --- a/go.sum +++ b/go.sum @@ -1,5 +1,31 @@ +github.com/Azure/go-autorest v13.3.0+incompatible h1:8Ix0VdeOllBx9jEcZ2Wb1uqWUpE1awmJiaHztwaJCPk= +github.com/Azure/go-autorest/autorest v0.9.0 h1:MRvx8gncNaXJqOoLmhNjUAKh33JJF8LyxPhomEtOsjs= +github.com/Azure/go-autorest/autorest v0.9.0/go.mod h1:xyHB1BMZT0cuDHU7I0+g046+BFDTQ8rEZB0s4Yfa6bI= +github.com/Azure/go-autorest/autorest v0.9.2 h1:6AWuh3uWrsZJcNoCHrCF/+g4aKPCU39kaMO6/qrnK/4= +github.com/Azure/go-autorest/autorest/adal v0.5.0/go.mod h1:8Z9fGy2MpX0PvDjB1pEgQTmVqjGhiHBW7RJJEciWzS0= +github.com/Azure/go-autorest/autorest/adal v0.8.0 h1:CxTzQrySOxDnKpLjFJeZAS5Qrv/qFPkgLjx5bOAi//I= +github.com/Azure/go-autorest/autorest/adal v0.8.0/go.mod h1:Z6vX6WXXuyieHAXwMj0S6HY6e6wcHn37qQMBQlvY3lc= +github.com/Azure/go-autorest/autorest/date v0.1.0/go.mod h1:plvfp3oPSKwf2DNjlBjWF/7vwR+cUD/ELuzDCXwHUVA= +github.com/Azure/go-autorest/autorest/date v0.2.0 h1:yW+Zlqf26583pE43KhfnhFcdmSWlm5Ew6bxipnr/tbM= +github.com/Azure/go-autorest/autorest/date v0.2.0/go.mod h1:vcORJHLJEh643/Ioh9+vPmf1Ij9AEBM5FuBIXLmIy0g= +github.com/Azure/go-autorest/autorest/mocks v0.1.0/go.mod h1:OTyCOPRA2IgIlWxVYxBee2F5Gr4kF2zd2J5cFRaIDN0= +github.com/Azure/go-autorest/autorest/mocks v0.2.0/go.mod h1:OTyCOPRA2IgIlWxVYxBee2F5Gr4kF2zd2J5cFRaIDN0= +github.com/Azure/go-autorest/autorest/mocks v0.3.0 h1:qJumjCaCudz+OcqE9/XtEPfvtOjOmKaui4EOpFI6zZc= +github.com/Azure/go-autorest/autorest/mocks v0.3.0/go.mod h1:a8FDP3DYzQ4RYfVAxAN3SVSiiO77gL2j2ronKKP0syM= +github.com/Azure/go-autorest/logger v0.1.0/go.mod h1:oExouG+K6PryycPJfVSxi/koC6LSNgds39diKLz7Vrc= +github.com/Azure/go-autorest/tracing v0.5.0 h1:TRn4WjSnkcSy5AEG3pnbtFSwNtwzjr4VYyQflFE619k= +github.com/Azure/go-autorest/tracing v0.5.0/go.mod h1:r/s2XiOKccPW3HrqB+W0TQzfbtp2fGCgRFtBroKn4Dk= +github.com/dgrijalva/jwt-go v3.2.0+incompatible h1:7qlOGliEKZXTDg6OTjfoBKDXWrumCAMpl/TFQ4/5kLM= +github.com/dgrijalva/jwt-go v3.2.0+incompatible/go.mod h1:E3ru+11k8xSBh+hMPgOLZmtrrCbhqsmaPHjLKYnJCaQ= github.com/golang-sql/civil v0.0.0-20190719163853-cb61b32ac6fe h1:lXe2qZdvpiX5WZkZR4hgp4KJVfY3nMkvmwbVkpv1rVY= github.com/golang-sql/civil v0.0.0-20190719163853-cb61b32ac6fe/go.mod h1:8vg3r2VgvsThLBIFL93Qb5yWzgyZWhEmBwUJWevAkK0= +golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= golang.org/x/crypto v0.0.0-20190325154230-a5d413f7728c h1:Vj5n4GlwjmQteupaxJ9+0FNOmBrHfq7vN4btdGoDZgI= golang.org/x/crypto v0.0.0-20190325154230-a5d413f7728c/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= +golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= +golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= +golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= +golang.org/x/tools v0.0.0-20191206204035-259af5ff87bd h1:Zc7EU2PqpsNeIfOoVA7hvQX4cS3YDJEs5KlfatT3hLo= +golang.org/x/tools v0.0.0-20191206204035-259af5ff87bd/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= +golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= diff --git a/log_conn.go b/log_conn.go new file mode 100644 index 00000000..4777e4c1 --- /dev/null +++ b/log_conn.go @@ -0,0 +1,80 @@ +package mssql + +import ( + "encoding/hex" + "net" + "strings" + "time" +) + +type connLogger struct { + conn net.Conn + readKind, writeKind string + readCount, writeCount int + logger Logger +} + +var _ net.Conn = &connLogger{} + +func newConnLogger(conn net.Conn, kind string, logger Logger) net.Conn { + if len(kind) > 0 && !strings.HasPrefix(kind, " ") { + kind = " " + kind + } + + cl := &connLogger{ + conn: conn, + readKind: "R" + kind, + writeKind: "W" + kind, + logger: logger, + } + + return cl +} + +func (cl *connLogger) Read(p []byte) (n int, err error) { + n, err = cl.conn.Read(p) + + if n > 0 { + dump := hex.Dump(p) + cl.logger.Printf("%s %d\n%s", cl.readKind, cl.readCount, dump) + cl.readCount += n + } + + return +} + +func (cl *connLogger) Write(p []byte) (n int, err error) { + n, err = cl.conn.Write(p) + + if n > 0 { + dump := hex.Dump(p) + cl.logger.Printf("%s %d\n%s", cl.writeKind, cl.writeCount, dump) + cl.writeCount += n + } + + return +} + +func (cl *connLogger) Close() (err error) { + return cl.conn.Close() +} + +func (cl *connLogger) LocalAddr() net.Addr { + return cl.conn.LocalAddr() +} + +func (cl *connLogger) RemoteAddr() net.Addr { + return cl.conn.RemoteAddr() +} + +func (cl *connLogger) SetDeadline(t time.Time) error { + return cl.conn.SetDeadline(t) +} + +func (cl *connLogger) SetReadDeadline(t time.Time) error { + return cl.conn.SetReadDeadline(t) +} + +func (cl *connLogger) SetWriteDeadline(t time.Time) error { + return cl.conn.SetWriteDeadline(t) +} diff --git a/mssql.go b/mssql.go index 5d815169..0ca04999 100644 --- a/mssql.go +++ b/mssql.go @@ -58,6 +58,7 @@ func (d *Driver) OpenConnector(dsn string) (*Connector, error) { if err != nil { return nil, err } + return &Connector{ params: params, driver: d, @@ -91,6 +92,21 @@ func NewConnector(dsn string) (*Connector, error) { return c, nil } +// SecurityTokenProvider implementations are called during federated +// authentication security token login sequences at the point when the +// security token is required. The string returned should be the access +// token to supply to the server, otherwise an error can be returned to +// indicate why a token is not available. +type SecurityTokenProvider func(ctx context.Context) (string, error) + +// ActiveDirectoryTokenProvider implementations are called during federated +// authentication login sequences where the server provides a service +// principal name and security token service endpoint that should be used +// to obtain the token. Implementations should contact the security token +// service specified and obtain the appropriate token, or return an error +// to indicate why a token is not available. +type ActiveDirectoryTokenProvider func(ctx context.Context, serverSPN, stsURL string) (string, error) + // Connector holds the parsed DSN and is ready to make a new connection // at any time. // @@ -126,6 +142,16 @@ type Connector struct { // Dialer sets a custom dialer for all network operations. // If Dialer is not set, normal net dialers are used. Dialer Dialer + + // SecurityTokenProvider sets the implementation that will obtain an + // authentication token when using federated authentication with the + // security token library. + SecurityTokenProvider SecurityTokenProvider + + // ActiveDirectoryTokenProvider sets the implementation that will + // obtain a authentication token when using federated authentication + // with the ActiveDirectory library. + ActiveDirectoryTokenProvider ActiveDirectoryTokenProvider } type Dialer interface { diff --git a/tds.go b/tds.go index 832c4fd2..55f344b5 100644 --- a/tds.go +++ b/tds.go @@ -10,6 +10,7 @@ import ( "io" "io/ioutil" "net" + "os" "sort" "strconv" "strings" @@ -89,12 +90,13 @@ const ( // 4.19.2 Out-of-Band Attention Signal: https://msdn.microsoft.com/en-us/library/dd305167.aspx packAttention = 6 - packBulkLoadBCP = 7 - packTransMgrReq = 14 - packNormal = 15 - packLogin7 = 16 - packSSPIMessage = 17 - packPrelogin = 18 + packBulkLoadBCP = 7 + packFedAuthToken = 8 + packTransMgrReq = 14 + packNormal = 15 + packLogin7 = 16 + packSSPIMessage = 17 + packPrelogin = 18 ) // prelogin fields @@ -118,6 +120,33 @@ const ( encryptReq = 3 // Encryption is required. ) +const ( + featExtSESSIONRECOVERY byte = 0x01 + featExtFEDAUTH byte = 0x02 + featExtCOLUMNENCRYPTION byte = 0x04 + featExtGLOBALTRANSACTIONS byte = 0x05 + featExtAZURESQLSUPPORT byte = 0x08 + featExtDATACLASSIFICATION byte = 0x09 + featExtUTF8SUPPORT byte = 0x0A + featExtTERMINATOR byte = 0xFF +) + +// Federated authentication library affects the login data structure and message sequence. +const ( + fedAuthLibraryLiveIDCompactToken = 0x00 + fedAuthLibrarySecurityToken = 0x01 + fedAuthLibraryADAL = 0x02 + + fedAuthLibraryReserved = 0x7F +) + +// Federated authentication ADAL workflow affects the mechanism used to authenticate. +const ( + fedAuthADALWorkflowPassword = 0x01 + fedAuthADALWorkflowIntegrated = 0x02 + fedAuthADALWorkflowMSI = 0x03 +) + type tdsSession struct { buf *tdsBuffer loginAck loginAckStruct @@ -139,6 +168,7 @@ const ( logParams = 16 logTransaction = 32 logDebug = 64 + logTraffic = 128 ) type columnStruct struct { @@ -240,6 +270,16 @@ const ( fIntSecurity = 0x80 ) +// OptionFlags3 +// http://msdn.microsoft.com/en-us/library/dd304019.aspx +const ( + fChangePassword = 1 + fSendYukonBinaryXML = 2 + fUserInstance = 4 + fUnknownCollationHandling = 8 + fExtension = 0x10 +) + // TypeFlags const ( // 4 bits for fSQLType @@ -247,119 +287,35 @@ const ( fReadOnlyIntent = 32 ) -// OptionFlags3 -// https://docs.microsoft.com/en-us/openspecs/windows_protocols/ms-tds/773a62b6-ee89-4c02-9e5e-344882630aac -const ( - fExtension = 0x10 -) - type login struct { - TDSVersion uint32 - PacketSize uint32 - ClientProgVer uint32 - ClientPID uint32 - ConnectionID uint32 - OptionFlags1 uint8 - OptionFlags2 uint8 - TypeFlags uint8 - OptionFlags3 uint8 - ClientTimeZone int32 - ClientLCID uint32 - HostName string - UserName string - Password string - AppName string - ServerName string - CtlIntName string - Language string - Database string - ClientID [6]byte - SSPI []byte - AtchDBFile string - ChangePassword string - FeatureExt featureExts -} - -type featureExts struct { - features map[byte]featureExt -} - -type featureExt interface { - featureID() byte - toBytes() []byte -} - -func (e *featureExts) Add(f featureExt) error { - if f == nil { - return nil - } - id := f.featureID() - if _, exists := e.features[id]; exists { - f := "Login error: Feature with ID '%v' is already present in FeatureExt block." - return fmt.Errorf(f, id) - } - if e.features == nil { - e.features = make(map[byte]featureExt) - } - e.features[id] = f - return nil -} - -func (e featureExts) toBytes() []byte { - if len(e.features) == 0 { - return nil - } - var d []byte - for featureID, f := range e.features { - featureData := f.toBytes() - - hdr := make([]byte, 5) - hdr[0] = featureID // FedAuth feature extension BYTE - binary.LittleEndian.PutUint32(hdr[1:], uint32(len(featureData))) // FeatureDataLen DWORD - d = append(d, hdr...) - - d = append(d, featureData...) // FeatureData *BYTE - } - if d != nil { - d = append(d, 0xff) // Terminator - } - return d -} - -type featureExtFedAuthSTS struct { - FedAuthEcho bool - FedAuthToken string - Nonce []byte -} - -func (e *featureExtFedAuthSTS) featureID() byte { - return 0x02 -} - -func (e *featureExtFedAuthSTS) toBytes() []byte { - if e == nil { - return nil - } - - options := byte(0x01) << 1 // 0x01 => STS bFedAuthLibrary 7BIT - if e.FedAuthEcho { - options |= 1 // fFedAuthEcho - } - - d := make([]byte, 5) - d[0] = options - - // looks like string in - // https://docs.microsoft.com/en-us/openspecs/windows_protocols/ms-tds/f88b63bb-b479-49e1-a87b-deda521da508 - tokenBytes := str2ucs2(e.FedAuthToken) - binary.LittleEndian.PutUint32(d[1:], uint32(len(tokenBytes))) // Should be a signed int32, but since the length is relatively small, this should work - d = append(d, tokenBytes...) - - if len(e.Nonce) == 32 { - d = append(d, e.Nonce...) - } - - return d + TDSVersion uint32 + PacketSize uint32 + ClientProgVer uint32 + ClientPID uint32 + ConnectionID uint32 + OptionFlags1 uint8 + OptionFlags2 uint8 + TypeFlags uint8 + OptionFlags3 uint8 + ClientTimeZone int32 + ClientLCID uint32 + HostName string + UserName string + Password string + AppName string + ServerName string + CtlIntName string + Language string + Database string + ClientID [6]byte + SSPI []byte + AtchDBFile string + ChangePassword string + FedAuthLibrary byte + FedAuthEcho byte + FedAuthToken string + FedAuthNonce []byte + FedAuthADALWorkflow byte } type loginHeader struct { @@ -448,44 +404,84 @@ func sendLogin(w *tdsBuffer, login login) error { database := str2ucs2(login.Database) atchdbfile := str2ucs2(login.AtchDBFile) changepassword := str2ucs2(login.ChangePassword) - featureExt := login.FeatureExt.toBytes() + fedauthtoken := str2ucs2(login.FedAuthToken) + + // Determine if any feature extensions need to be written so we know whether + // to include the option flag and offset to the data. + var featureExtLength uint32 + fedAuth := true + switch login.FedAuthLibrary { + case fedAuthLibrarySecurityToken: + // Each feature extension record is 1 byte to indicate the type of the + // feature extension, four bytes for the data size, then the data size. + // For the SecurityToken data, the size is one byte to indicate that + // it's the SecurityToken library, four bytes for the token length, + // then the token bytes. + featureExtLength += uint32(1 + 4 + 1 + 4 + len(fedauthtoken)) + case fedAuthLibraryADAL: + // In addition to the 1 + 4 bytes for the feature extension header, + // ADAL just requires one byte to indicate the library and one to + // set the workflow (password, integrated or MSI). + featureExtLength += uint32(1 + 4 + 1 + 1) + default: + fedAuth = false + } + + // If any feature extension records are written, a final single-byte terminator + // record must also be written, and the fExtension flag must be set. + if featureExtLength > 0 { + featureExtLength++ + login.OptionFlags3 |= fExtension + } else { + login.OptionFlags3 &^= fExtension + } hdr := loginHeader{ - TDSVersion: login.TDSVersion, - PacketSize: login.PacketSize, - ClientProgVer: login.ClientProgVer, - ClientPID: login.ClientPID, - ConnectionID: login.ConnectionID, - OptionFlags1: login.OptionFlags1, - OptionFlags2: login.OptionFlags2, - TypeFlags: login.TypeFlags, - OptionFlags3: login.OptionFlags3, - ClientTimeZone: login.ClientTimeZone, - ClientLCID: login.ClientLCID, - HostNameLength: uint16(utf8.RuneCountInString(login.HostName)), - UserNameLength: uint16(utf8.RuneCountInString(login.UserName)), - PasswordLength: uint16(utf8.RuneCountInString(login.Password)), - AppNameLength: uint16(utf8.RuneCountInString(login.AppName)), - ServerNameLength: uint16(utf8.RuneCountInString(login.ServerName)), - CtlIntNameLength: uint16(utf8.RuneCountInString(login.CtlIntName)), - LanguageLength: uint16(utf8.RuneCountInString(login.Language)), - DatabaseLength: uint16(utf8.RuneCountInString(login.Database)), - ClientID: login.ClientID, - SSPILength: uint16(len(login.SSPI)), - AtchDBFileLength: uint16(utf8.RuneCountInString(login.AtchDBFile)), - ChangePasswordLength: uint16(utf8.RuneCountInString(login.ChangePassword)), + TDSVersion: login.TDSVersion, + PacketSize: login.PacketSize, + ClientProgVer: login.ClientProgVer, + ClientPID: login.ClientPID, + ConnectionID: login.ConnectionID, + OptionFlags1: login.OptionFlags1, + OptionFlags2: login.OptionFlags2, + TypeFlags: login.TypeFlags, + OptionFlags3: login.OptionFlags3, + ClientTimeZone: login.ClientTimeZone, + ClientLCID: login.ClientLCID, + HostNameLength: uint16(utf8.RuneCountInString(login.HostName)), + AppNameLength: uint16(utf8.RuneCountInString(login.AppName)), + ServerNameLength: uint16(utf8.RuneCountInString(login.ServerName)), + CtlIntNameLength: uint16(utf8.RuneCountInString(login.CtlIntName)), + LanguageLength: uint16(utf8.RuneCountInString(login.Language)), + DatabaseLength: uint16(utf8.RuneCountInString(login.Database)), + ClientID: login.ClientID, + SSPILength: uint16(len(login.SSPI)), + AtchDBFileLength: uint16(utf8.RuneCountInString(login.AtchDBFile)), } offset := uint16(binary.Size(hdr)) hdr.HostNameOffset = offset offset += uint16(len(hostname)) - hdr.UserNameOffset = offset - offset += uint16(len(username)) - hdr.PasswordOffset = offset - offset += uint16(len(password)) + + if !fedAuth { + hdr.UserNameOffset = offset + hdr.UserNameLength = uint16(utf8.RuneCountInString(login.UserName)) + offset += uint16(len(username)) + hdr.PasswordOffset = offset + hdr.PasswordLength = uint16(utf8.RuneCountInString(login.Password)) + offset += uint16(len(password)) + } + hdr.AppNameOffset = offset offset += uint16(len(appname)) hdr.ServerNameOffset = offset offset += uint16(len(servername)) + + if featureExtLength > 0 { + hdr.ExtensionOffset = offset + hdr.ExtensionLength = 4 + offset += hdr.ExtensionLength + } + hdr.CtlIntNameOffset = offset offset += uint16(len(ctlintname)) hdr.LanguageOffset = offset @@ -496,20 +492,14 @@ func sendLogin(w *tdsBuffer, login login) error { offset += uint16(len(login.SSPI)) hdr.AtchDBFileOffset = offset offset += uint16(len(atchdbfile)) - hdr.ChangePasswordOffset = offset - offset += uint16(len(changepassword)) - featureExtOffset := uint32(0) - featureExtLen := len(featureExt) - if featureExtLen > 0 { - hdr.OptionFlags3 |= fExtension - hdr.ExtensionOffset = offset - hdr.ExtensionLength = 4 - offset += hdr.ExtensionLength // DWORD - featureExtOffset = uint32(offset) + if !fedAuth { + hdr.ChangePasswordOffset = offset + hdr.ChangePasswordLength = uint16(utf8.RuneCountInString(login.ChangePassword)) + offset += uint16(len(changepassword)) } - hdr.Length = uint32(offset) + uint32(featureExtLen) + hdr.Length = uint32(offset) + featureExtLength var err error err = binary.Write(w, binary.LittleEndian, &hdr) if err != nil { @@ -519,13 +509,16 @@ func sendLogin(w *tdsBuffer, login login) error { if err != nil { return err } - _, err = w.Write(username) - if err != nil { - return err - } - _, err = w.Write(password) - if err != nil { - return err + + if !fedAuth { + _, err = w.Write(username) + if err != nil { + return err + } + _, err = w.Write(password) + if err != nil { + return err + } } _, err = w.Write(appname) if err != nil { @@ -535,6 +528,12 @@ func sendLogin(w *tdsBuffer, login login) error { if err != nil { return err } + if featureExtLength > 0 { + err = binary.Write(w, binary.LittleEndian, uint32(offset)) + if err != nil { + return err + } + } _, err = w.Write(ctlintname) if err != nil { return err @@ -555,20 +554,63 @@ func sendLogin(w *tdsBuffer, login login) error { if err != nil { return err } - _, err = w.Write(changepassword) - if err != nil { - return err - } - if featureExtOffset > 0 { - err = binary.Write(w, binary.LittleEndian, featureExtOffset) - if err != nil { - return err - } - _, err = w.Write(featureExt) + + if !fedAuth { + _, err = w.Write(changepassword) if err != nil { return err } } + + // Write the feature extension record for federated authentication, if in use + switch login.FedAuthLibrary { + case fedAuthLibrarySecurityToken: + w.WriteByte(featExtFEDAUTH) + binary.Write(w, binary.LittleEndian, uint32(1+4+len(fedauthtoken))) + w.WriteByte(login.FedAuthLibrary<<1 | login.FedAuthEcho) + binary.Write(w, binary.LittleEndian, uint32(len(fedauthtoken))) + w.Write(fedauthtoken) + case fedAuthLibraryADAL: + w.WriteByte(featExtFEDAUTH) + binary.Write(w, binary.LittleEndian, uint32(1+1)) + w.WriteByte(login.FedAuthLibrary<<1 | login.FedAuthEcho) + w.WriteByte(login.FedAuthADALWorkflow) + } + + // Write the feature extension terminator if any feature extensions are written + if featureExtLength > 0 { + w.WriteByte(featExtTERMINATOR) + } + return w.FinishPacket() +} + +// https://docs.microsoft.com/en-us/openspecs/windows_protocols/ms-tds/827d9632-2957-4d54-b9ea-384530ae79d0 +func sendFedAuthInfo(w *tdsBuffer, login login) (err error) { + fedauthtoken := str2ucs2(login.FedAuthToken) + tokenlen := len(fedauthtoken) + datalen := 4 + tokenlen + len(login.FedAuthNonce) + + w.BeginPacket(packFedAuthToken, false) + err = binary.Write(w, binary.LittleEndian, uint32(datalen)) + if err != nil { + return + } + + err = binary.Write(w, binary.LittleEndian, uint32(tokenlen)) + if err != nil { + return + } + + _, err = w.Write(fedauthtoken) + if err != nil { + return + } + + _, err = w.Write(login.FedAuthNonce) + if err != nil { + return + } + return w.FinishPacket() } @@ -831,6 +873,11 @@ func dialConnection(ctx context.Context, c *Connector, p connectParams) (conn ne } func connect(ctx context.Context, c *Connector, log optionalLogger, p connectParams) (res *tdsSession, err error) { + err = c.checkFedAuthProviders(&p) + if err != nil { + return nil, err + } + dialCtx := ctx if p.dial_timeout > 0 { var cancel func() @@ -865,6 +912,10 @@ initiate_connection: return nil, err } + if p.logFlags&logTraffic != 0 { + conn = newConnLogger(conn, "TCP", log) + } + toconn := newTimeoutConn(conn, p.conn_timeout) outbuf := newTdsBuffer(p.packetSize, toconn) @@ -892,6 +943,10 @@ initiate_connection: preloginMARS: {0}, // MARS disabled } + if p.fedAuthLibrary != fedAuthLibraryReserved { + fields[preloginFEDAUTHREQUIRED] = []byte{1} + } + err = writePrelogin(outbuf, fields) if err != nil { return nil, err @@ -902,6 +957,21 @@ initiate_connection: return nil, err } + // If the server returns the preloginFEDAUTHREQUIRED field, then federated authentication + // is supported. The actual value may be 0 or 1, where 0 means either SSPI or federated + // authentication is allowed, while 1 means only federated authentication is allowed. + var fedAuthEcho byte + if fedAuthSupport, ok := fields[preloginFEDAUTHREQUIRED]; ok { + if len(fedAuthSupport) != 1 { + return nil, fmt.Errorf("Federated authentication flag length should be 1: is %d", len(fedAuthSupport)) + } + + // We need to be able to echo the value back to the server + fedAuthEcho = fedAuthSupport[0] + } else if p.fedAuthLibrary != fedAuthLibraryReserved { + return nil, fmt.Errorf("Federated authentication is not supported by the server") + } + encryptBytes, ok := fields[preloginENCRYPTION] if !ok { return nil, fmt.Errorf("Encrypt negotiation failed") @@ -925,6 +995,14 @@ initiate_connection: if p.trustServerCertificate { config.InsecureSkipVerify = true } + if p.tlsKeyLogFile != "" { + if w, err := os.OpenFile(p.tlsKeyLogFile, os.O_WRONLY|os.O_CREATE|os.O_APPEND, 0600); err == nil { + defer w.Close() + config.KeyLogWriter = w + } else { + return nil, fmt.Errorf("Cannot open TLS key log file %s: %v", p.tlsKeyLogFile, err) + } + } config.ServerName = p.hostInCertificate // fix for https://github.com/denisenkom/go-mssqldb/issues/166 // Go implementation of TLS payload size heuristic algorithm splits single TDS package to multiple TCP segments, @@ -937,7 +1015,11 @@ initiate_connection: tlsConn := tls.Client(&passthrough, &config) err = tlsConn.Handshake() passthrough.c = toconn - outbuf.transport = tlsConn + if sess.logFlags&logTraffic != 0 { + outbuf.transport = newConnLogger(tlsConn, "TLS", log) + } else { + outbuf.transport = tlsConn + } if err != nil { return nil, fmt.Errorf("TLS Handshake failed: %v", err) } @@ -949,25 +1031,52 @@ initiate_connection: } login := login{ - TDSVersion: verTDS74, - PacketSize: uint32(outbuf.PackageSize()), - Database: p.database, - OptionFlags2: fODBC, // to get unlimited TEXTSIZE - HostName: p.workstation, - ServerName: p.host, - AppName: p.appname, - TypeFlags: p.typeFlags, - } - auth, authOk := getAuth(p.user, p.password, p.serverSPN, p.workstation) + TDSVersion: verTDS74, + PacketSize: uint32(outbuf.PackageSize()), + Database: p.database, + OptionFlags2: fODBC, // to get unlimited TEXTSIZE + HostName: p.workstation, + ServerName: p.host, + AppName: p.appname, + TypeFlags: p.typeFlags, + FedAuthLibrary: p.fedAuthLibrary, + FedAuthEcho: fedAuthEcho, + FedAuthADALWorkflow: p.fedAuthADALWorkflow, + } + auth, auth_ok := getAuth(p.user, p.password, p.serverSPN, p.workstation) switch { - case p.fedAuthAccessToken != "": // accesstoken ignores user/password - featurext := &featureExtFedAuthSTS{ - FedAuthEcho: len(fields[preloginFEDAUTHREQUIRED]) > 0 && fields[preloginFEDAUTHREQUIRED][0] == 1, - FedAuthToken: p.fedAuthAccessToken, - Nonce: fields[preloginNONCEOPT], + case login.FedAuthLibrary == fedAuthLibrarySecurityToken: + if p.logFlags&logDebug != 0 { + log.Println("Starting federated authentication using security token") + } + + if p.securityTokenProvider == nil { + return nil, fmt.Errorf("No security token provider configured to support federated authentication") + } + + login.FedAuthToken, err = p.securityTokenProvider(ctx) + if err != nil { + if p.logFlags&logDebug != 0 { + log.Printf("Failed to retrieve service principal token for federated authentication security token library: %v", err) + } + return nil, err + } + case login.FedAuthLibrary == fedAuthLibraryADAL && login.FedAuthADALWorkflow == fedAuthADALWorkflowPassword: + if p.logFlags&logDebug != 0 { + log.Printf("Starting ADAL username/password workflow for user %s", p.user) + } + login.UserName = p.user + login.Password = p.password + case login.FedAuthLibrary == fedAuthLibraryADAL && login.FedAuthADALWorkflow == fedAuthADALWorkflowMSI: + if p.logFlags&logDebug != 0 { + log.Println("Starting ADAL managed service identity (MSI) workflow") + } + case login.FedAuthLibrary == fedAuthLibraryADAL: + return nil, fmt.Errorf("Unsupported ADAL workflow type 0x%02x", int(login.FedAuthADALWorkflow)) + case auth_ok: + if p.logFlags&logDebug != 0 { + log.Println("Starting SSPI login") } - login.FeatureExt.Add(featurext) - case authOk: login.SSPI, err = auth.InitialBytes() if err != nil { return nil, err @@ -975,6 +1084,7 @@ initiate_connection: login.OptionFlags2 |= fIntSecurity defer auth.Free() default: + // Default to SQL server authentication with user and password login.UserName = p.user login.Password = p.password } @@ -1007,6 +1117,23 @@ initiate_connection: } sspi_msg = nil } + case fedAuthInfoStruct: + // For ADAL workflows this contains the STS URL and server SPN. + // If received outside of an ADAL workflow, ignore. + if login.FedAuthLibrary == fedAuthLibraryADAL { + if p.activeDirectoryTokenProvider == nil { + return nil, fmt.Errorf("No Active Directory token provider configured to support federated authentication") + } + login.FedAuthToken, err = p.activeDirectoryTokenProvider(ctx, token.ServerSPN, token.STSURL) + if err != nil { + return nil, err + } + // Now need to send the token as a FEDINFO packet + err = sendFedAuthInfo(outbuf, login) + if err != nil { + return nil, err + } + } case loginAckStruct: success = true sess.loginAck = token diff --git a/tds_test.go b/tds_test.go index e725a668..9441a642 100644 --- a/tds_test.go +++ b/tds_test.go @@ -89,24 +89,26 @@ func TestSendLoginWithFeatureExt(t *testing.T) { Database: "database", ClientLCID: 0x204, } - login.FeatureExt.Add(&featureExtFedAuthSTS{ - FedAuthToken: "fedauthtoken", - }) + login.FedAuthLibrary = fedAuthLibrarySecurityToken + login.FedAuthToken = "fedauthtoken" err := sendLogin(buf, login) if err != nil { t.Error("sendLogin should succeed") } ref := []byte{ - 16, 1, 0, 223, 0, 0, 1, 0, 215, 0, 0, 0, 4, 0, 0, 116, 0, 16, 0, 0, 0, 1, - 6, 1, 100, 0, 0, 0, 0, 0, 0, 0, 224, 0, 0, 24, 16, 255, 255, 255, 4, 2, 0, - 0, 94, 0, 7, 0, 108, 0, 0, 0, 108, 0, 0, 0, 108, 0, 7, 0, 122, 0, 10, 0, 176, - 0, 4, 0, 142, 0, 7, 0, 156, 0, 2, 0, 160, 0, 8, 0, 18, 52, 86, 120, 144, 171, - 176, 0, 0, 0, 176, 0, 0, 0, 176, 0, 0, 0, 0, 0, 0, 0, 115, 0, 117, 0, 98, - 0, 100, 0, 101, 0, 118, 0, 49, 0, 97, 0, 112, 0, 112, 0, 110, 0, 97, 0, - 109, 0, 101, 0, 115, 0, 101, 0, 114, 0, 118, 0, 101, 0, 114, 0, 110, 0, 97, - 0, 109, 0, 101, 0, 108, 0, 105, 0, 98, 0, 114, 0, 97, 0, 114, 0, 121, 0, 101, - 0, 110, 0, 100, 0, 97, 0, 116, 0, 97, 0, 98, 0, 97, 0, 115, 0, 101, 0, 180, 0, - 0, 0, 2, 29, 0, 0, 0, 2, 24, 0, 0, 0, 102, 0, 101, 0, 100, 0, 97, 0, 117, 0, + 16, 1, 0, 223, 0, 0, 1, 0, 215, 0, 0, 0, 4, 0, 0, 116, + 0, 16, 0, 0, 0, 1, 6, 1, 100, 0, 0, 0, 0, 0, 0, 0, + 224, 0, 0, 24, 16, 255, 255, 255, 4, 2, 0, 0, 94, 0, 7, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 108, 0, 7, 0, 122, 0, 10, 0, + 142, 0, 4, 0, 146, 0, 7, 0, 160, 0, 2, 0, 164, 0, 8, 0, + 18, 52, 86, 120, 144, 171, 180, 0, 0, 0, 180, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 115, 0, 117, 0, 98, 0, 100, 0, 101, 0, + 118, 0, 49, 0, 97, 0, 112, 0, 112, 0, 110, 0, 97, 0, 109, 0, + 101, 0, 115, 0, 101, 0, 114, 0, 118, 0, 101, 0, 114, 0, 110, 0, + 97, 0, 109, 0, 101, 0, 180, 0, 0, 0, 108, 0, 105, 0, 98, 0, + 114, 0, 97, 0, 114, 0, 121, 0, 101, 0, 110, 0, 100, 0, 97, 0, + 116, 0, 97, 0, 98, 0, 97, 0, 115, 0, 101, 0, 2, 29, 0, 0, + 0, 2, 24, 0, 0, 0, 102, 0, 101, 0, 100, 0, 97, 0, 117, 0, 116, 0, 104, 0, 116, 0, 111, 0, 107, 0, 101, 0, 110, 0, 255} out := memBuf.Bytes() if !bytes.Equal(ref, out) { diff --git a/token.go b/token.go index 25385e89..ef880935 100644 --- a/token.go +++ b/token.go @@ -6,12 +6,13 @@ import ( "errors" "fmt" "io" + "io/ioutil" "net" "strconv" "strings" ) -//go:generate stringer -type token +//go:generate go run golang.org/x/tools/cmd/stringer -type token type token byte @@ -29,6 +30,7 @@ const ( tokenNbcRow token = 210 // 0xd2 tokenEnvChange token = 227 // 0xE3 tokenSSPI token = 237 // 0xED + tokenFedAuthInfo token = 238 // 0xEE tokenDone token = 253 // 0xFD tokenDoneProc token = 254 tokenDoneInProc token = 255 @@ -70,6 +72,11 @@ const ( envRouting = 20 ) +const ( + fedAuthInfoSTSURL = 0x01 + fedAuthInfoSPN = 0x02 +) + // COLMETADATA flags // https://msdn.microsoft.com/en-us/library/dd357363.aspx const ( @@ -425,6 +432,78 @@ func parseSSPIMsg(r *tdsBuffer) sspiMsg { return sspiMsg(buf) } +type fedAuthInfoStruct struct { + STSURL string + ServerSPN string +} + +type fedAuthInfoOpt struct { + fedAuthInfoID byte + dataLength, dataOffset uint32 +} + +func parseFedAuthInfo(r *tdsBuffer) fedAuthInfoStruct { + size := r.uint32() + + var STSURL, SPN string + var err error + + // Each fedAuthInfoOpt is one byte to indicate the info ID, + // then a four byte offset and a four byte length. + count := r.uint32() + offset := uint32(4) + opts := make([]fedAuthInfoOpt, count) + + for i := uint32(0); i < count; i++ { + fedAuthInfoID := r.byte() + dataLength := r.uint32() + dataOffset := r.uint32() + offset += 1 + 4 + 4 + + opts[i] = fedAuthInfoOpt{ + fedAuthInfoID: fedAuthInfoID, + dataLength: dataLength, + dataOffset: dataOffset, + } + } + + data := make([]byte, size-offset) + r.ReadFull(data) + + for i := uint32(0); i < count; i++ { + if opts[i].dataOffset < offset { + badStreamPanicf("Fed auth info opt stated data offset %d is before data begins in packet at %d", + opts[i].dataOffset, offset) + // returns via panic + } + + if opts[i].dataOffset+opts[i].dataLength > size { + badStreamPanicf("Fed auth info opt stated data length %d added to stated offset exceeds size of packet %d", + opts[i].dataOffset+opts[i].dataLength, size) + // returns via panic + } + + optData := data[opts[i].dataOffset-offset : opts[i].dataOffset-offset+opts[i].dataLength] + switch opts[i].fedAuthInfoID { + case fedAuthInfoSTSURL: + STSURL, err = ucs22str(optData) + case fedAuthInfoSPN: + SPN, err = ucs22str(optData) + default: + err = fmt.Errorf("Unexpected fed auth info opt ID %d", int(opts[i].fedAuthInfoID)) + } + + if err != nil { + badStreamPanic(err) + } + } + + return fedAuthInfoStruct{ + STSURL: STSURL, + ServerSPN: SPN, + } +} + type loginAckStruct struct { Interface uint8 TDSVersion uint32 @@ -449,19 +528,43 @@ func parseLoginAck(r *tdsBuffer) loginAckStruct { } // https://docs.microsoft.com/en-us/openspecs/windows_protocols/ms-tds/2eb82f8e-11f0-46dc-b42d-27302fa4701a -func parseFeatureExtAck(r *tdsBuffer) { - // at most 1 featureAck per feature in featureExt - // go-mssqldb will add at most 1 feature, the spec defines 7 different features - for i := 0; i < 8; i++ { - featureID := r.byte() // FeatureID - if featureID == 0xff { - return +type fedAuthAckStruct struct { + Nonce []byte + Signature []byte +} + +func parseFeatureExtAck(r *tdsBuffer) map[byte]interface{} { + ack := map[byte]interface{}{} + + for feature := r.byte(); feature != featExtTERMINATOR; feature = r.byte() { + length := r.uint32() + + switch feature { + case featExtFEDAUTH: + // In theory we need to know the federated authentication library to + // know how to parse, but the alternatives provide compatible structures. + fedAuthAck := fedAuthAckStruct{} + if length >= 32 { + fedAuthAck.Nonce = make([]byte, 0, 32) + r.ReadFull(fedAuthAck.Nonce) + length -= 32 + } + if length >= 32 { + fedAuthAck.Signature = make([]byte, 0, 32) + r.ReadFull(fedAuthAck.Signature) + length -= 32 + } + ack[feature] = fedAuthAck + + } + + // Skip unprocessed bytes + if length > 0 { + io.CopyN(ioutil.Discard, r, int64(length)) } - size := r.uint32() // FeatureAckDataLen - d := make([]byte, size) - r.ReadFull(d) } - panic("parsed more than 7 featureAck's, protocol implementation error?") + + return ack } // http://msdn.microsoft.com/en-us/library/dd357363.aspx @@ -588,6 +691,9 @@ func processSingleResponse(sess *tdsSession, ch chan tokenStruct, outs map[strin case tokenSSPI: ch <- parseSSPIMsg(sess.buf) return + case tokenFedAuthInfo: + ch <- parseFedAuthInfo(sess.buf) + return case tokenReturnStatus: returnStatus := parseReturnStatus(sess.buf) ch <- returnStatus @@ -595,7 +701,8 @@ func processSingleResponse(sess *tdsSession, ch chan tokenStruct, outs map[strin loginAck := parseLoginAck(sess.buf) ch <- loginAck case tokenFeatureExtAck: - parseFeatureExtAck(sess.buf) + featureExtAck := parseFeatureExtAck(sess.buf) + ch <- featureExtAck case tokenOrder: order := parseOrder(sess.buf) ch <- order diff --git a/token_string.go b/token_string.go index c075b23b..74389fdd 100644 --- a/token_string.go +++ b/token_string.go @@ -1,29 +1,46 @@ -// Code generated by "stringer -type token"; DO NOT EDIT +// Code generated by "stringer -type token"; DO NOT EDIT. package mssql -import "fmt" +import "strconv" + +func _() { + // An "invalid array index" compiler error signifies that the constant values have changed. + // Re-run the stringer command to generate them again. + var x [1]struct{} + _ = x[tokenReturnStatus-121] + _ = x[tokenColMetadata-129] + _ = x[tokenOrder-169] + _ = x[tokenError-170] + _ = x[tokenInfo-171] + _ = x[tokenReturnValue-172] + _ = x[tokenLoginAck-173] + _ = x[tokenFeatureExtAck-174] + _ = x[tokenRow-209] + _ = x[tokenNbcRow-210] + _ = x[tokenEnvChange-227] + _ = x[tokenSSPI-237] + _ = x[tokenFedAuthInfo-238] + _ = x[tokenDone-253] + _ = x[tokenDoneProc-254] + _ = x[tokenDoneInProc-255] +} const ( _token_name_0 = "tokenReturnStatus" _token_name_1 = "tokenColMetadata" - _token_name_2 = "tokenOrdertokenErrortokenInfo" - _token_name_3 = "tokenLoginAck" - _token_name_4 = "tokenRowtokenNbcRow" - _token_name_5 = "tokenEnvChange" - _token_name_6 = "tokenSSPI" - _token_name_7 = "tokenDonetokenDoneProctokenDoneInProc" + _token_name_2 = "tokenOrdertokenErrortokenInfotokenReturnValuetokenLoginAcktokenFeatureExtAck" + _token_name_3 = "tokenRowtokenNbcRow" + _token_name_4 = "tokenEnvChange" + _token_name_5 = "tokenSSPItokenFedAuthInfo" + _token_name_6 = "tokenDonetokenDoneProctokenDoneInProc" ) var ( - _token_index_0 = [...]uint8{0, 17} - _token_index_1 = [...]uint8{0, 16} - _token_index_2 = [...]uint8{0, 10, 20, 29} - _token_index_3 = [...]uint8{0, 13} - _token_index_4 = [...]uint8{0, 8, 19} - _token_index_5 = [...]uint8{0, 14} - _token_index_6 = [...]uint8{0, 9} - _token_index_7 = [...]uint8{0, 9, 22, 37} + _token_index_2 = [...]uint8{0, 10, 20, 29, 45, 58, 76} + _token_index_3 = [...]uint8{0, 8, 19} + _token_index_5 = [...]uint8{0, 9, 25} + _token_index_6 = [...]uint8{0, 9, 22, 37} ) func (i token) String() string { @@ -32,22 +49,21 @@ func (i token) String() string { return _token_name_0 case i == 129: return _token_name_1 - case 169 <= i && i <= 171: + case 169 <= i && i <= 174: i -= 169 return _token_name_2[_token_index_2[i]:_token_index_2[i+1]] - case i == 173: - return _token_name_3 case 209 <= i && i <= 210: i -= 209 - return _token_name_4[_token_index_4[i]:_token_index_4[i+1]] + return _token_name_3[_token_index_3[i]:_token_index_3[i+1]] case i == 227: - return _token_name_5 - case i == 237: - return _token_name_6 + return _token_name_4 + case 237 <= i && i <= 238: + i -= 237 + return _token_name_5[_token_index_5[i]:_token_index_5[i+1]] case 253 <= i && i <= 255: i -= 253 - return _token_name_7[_token_index_7[i]:_token_index_7[i+1]] + return _token_name_6[_token_index_6[i]:_token_index_6[i+1]] default: - return fmt.Sprintf("token(%d)", i) + return "token(" + strconv.FormatInt(int64(i), 10) + ")" } } diff --git a/tvp_go19_db_test.go b/tvp_go19_db_test.go index 6cf42641..f4c2ddff 100644 --- a/tvp_go19_db_test.go +++ b/tvp_go19_db_test.go @@ -209,7 +209,8 @@ func TestTVP(t *testing.T) { bFalse := false floatValue64 := 0.123 floatValue32 := float32(-10.123) - timeNow := time.Now().UTC() + // Best datetime2 precision is 100 ns granularity + timeNow := time.Now().UTC().Truncate(100 * time.Nanosecond) param1 := []TvptableRow{ { PBinary: []byte("ccc"), @@ -462,7 +463,8 @@ func TestTVP_WithTag(t *testing.T) { bFalse := false floatValue64 := 0.123 floatValue32 := float32(-10.123) - timeNow := time.Now().UTC() + // Default (and maximum) datetime2 precision is 7 digits or 100ns + timeNow := time.Now().UTC().Truncate(100 * time.Nanosecond) param1 := []TvptableRowWithSkipTag{ { PBinary: []byte("ccc"),