From de9623115fc297f87326b0f26b4d3d8acb09d87b Mon Sep 17 00:00:00 2001 From: "Rose, William" Date: Sun, 8 Dec 2019 17:51:17 -0800 Subject: [PATCH 1/2] Added support for federated authentication to enable Azure AD authentication --- .gitignore | 5 + README.md | 8 +- conn_str.go | 51 ++ doc/how-to-test-azure-ad-authentication.md | 176 +++++++ examples/azuread/ad-service-principal-dsn.jq | 1 + examples/azuread/ad-system-assigned-id-dsn.jq | 1 + examples/azuread/ad-user-assigned-id-dsn.jq | 1 + examples/azuread/ad-user-password-dsn.jq | 1 + examples/azuread/azuread.go | 144 +++++ examples/azuread/environment-settings.jq | 20 + examples/azuread/testing.tf | 490 ++++++++++++++++++ examples/simple/simple.go | 39 +- fedauth.go | 101 ++++ fedauth_test.go | 189 +++++++ go.mod | 2 + go.sum | 25 + log_conn.go | 80 +++ tds.go | 403 +++++++++++--- token.go | 150 +++++- token_string.go | 66 ++- tvp_go19_db_test.go | 6 +- 21 files changed, 1818 insertions(+), 141 deletions(-) create mode 100644 .gitignore create mode 100644 doc/how-to-test-azure-ad-authentication.md create mode 100644 examples/azuread/ad-service-principal-dsn.jq create mode 100644 examples/azuread/ad-system-assigned-id-dsn.jq create mode 100644 examples/azuread/ad-user-assigned-id-dsn.jq create mode 100644 examples/azuread/ad-user-password-dsn.jq create mode 100644 examples/azuread/azuread.go create mode 100644 examples/azuread/environment-settings.jq create mode 100644 examples/azuread/testing.tf create mode 100644 fedauth.go create mode 100644 fedauth_test.go create mode 100644 log_conn.go diff --git a/.gitignore b/.gitignore new file mode 100644 index 00000000..f110691d --- /dev/null +++ b/.gitignore @@ -0,0 +1,5 @@ +.terraform +*.tfstate* +*.log +*.swp +*~ diff --git a/README.md b/README.md index b655176b..08049576 100644 --- a/README.md +++ b/README.md @@ -54,10 +54,14 @@ Other supported formats are listed below. * true - Server certificate is not checked. Default is true if encrypt is not specified. If trust server certificate is true, driver accepts any certificate presented by the server and any host name in that certificate. In this mode, TLS is susceptible to man-in-the-middle attacks. This should be used only for testing. * `certificate` - The file that contains the public key certificate of the CA that signed the SQL Server certificate. The specified certificate overrides the go platform specific CA certificates. * `hostNameInCertificate` - Specifies the Common Name (CN) in the server certificate. Default value is the server host. -* `ServerSPN` - The kerberos SPN (Service Principal Name) for the server. Default is MSSQLSvc/host:port. +* `ServerSPN` - The Kerberos SPN (Service Principal Name) for the server. Default is MSSQLSvc/host:port. * `Workstation ID` - The workstation name (default is the host name) * `ApplicationIntent` - Can be given the value `ReadOnly` to initiate a read-only connection to an Availability Group listener. The `database` must be specified when connecting with `Application Intent` set to `ReadOnly`. - +* `FedAuth` - The federated authentication scheme to use. + * `ActiveDirectoryApplication` - authenticates using an Azure Active Directory application client ID and client secret or certificate. Set the `user` to `client-ID@tenant-ID` and the `password` to the client secret. If using client certificates, provide the path to the PKCS#12 file containing the certificate and RSA private key in the `ClientCertPath` parameter, and set the `password` to the value needed to open the PKCS#12 file. + * `ActiveDirectoryMSI` - authenticates using the managed service identity (MSI) attached to the VM, or a specific user-assigned identity if a client ID is specified in the `user` field. + * `ActiveDirectoryPassword` - authenticates an Azure Active Directory user account in the form `user@domain.com` with a password. This method is not recommended for general use and does not support multi-factor authentication for accounts. + ### The connection string can be specified in one of three formats: diff --git a/conn_str.go b/conn_str.go index 4ff54b89..4d77f092 100644 --- a/conn_str.go +++ b/conn_str.go @@ -1,6 +1,7 @@ package mssql import ( + "errors" "fmt" "net" "net/url" @@ -13,6 +14,12 @@ import ( const defaultServerPort = 1433 +const ( + fedAuthActiveDirectoryPassword = "ActiveDirectoryPassword" + fedAuthActiveDirectoryMSI = "ActiveDirectoryMSI" + fedAuthActiveDirectoryApplication = "ActiveDirectoryApplication" +) + type connectParams struct { logFlags uint64 port uint64 @@ -37,6 +44,11 @@ type connectParams struct { failOverPartner string failOverPort uint64 packetSize uint16 + fedAuthLibrary byte + fedAuthADALWorkflow byte + aadTenantID string + aadClientCertPath string + tlsKeyLogFile string } func parseConnectParams(dsn string) (connectParams, error) { @@ -229,6 +241,45 @@ func parseConnectParams(dsn string) (connectParams, error) { } } + p.fedAuthLibrary = fedAuthLibraryReserved + fedAuth, ok := params["fedauth"] + if ok { + switch { + case strings.EqualFold(fedAuth, fedAuthActiveDirectoryPassword): + p.fedAuthLibrary = fedAuthLibraryADAL + p.fedAuthADALWorkflow = fedAuthADALWorkflowPassword + case strings.EqualFold(fedAuth, fedAuthActiveDirectoryMSI): + p.fedAuthLibrary = fedAuthLibraryADAL + p.fedAuthADALWorkflow = fedAuthADALWorkflowMSI + case strings.EqualFold(fedAuth, fedAuthActiveDirectoryApplication): + p.fedAuthLibrary = fedAuthLibrarySecurityToken + p.aadClientCertPath = params["clientcertpath"] + + // Split the user name into client id and tenant id at the @ symbol + at := strings.IndexRune(p.user, '@') + if at < 1 || at >= (len(p.user)-1) { + f := "Expecting user id to be clientID@tenantID: found '%s'" + return p, fmt.Errorf(f, p.user) + } + + p.aadTenantID = p.user[at+1:] + p.user = p.user[0:at] + default: + f := "Invalid federated authentication type '%s': expected %s, %s or %s" + return p, fmt.Errorf(f, fedAuth, fedAuthActiveDirectoryPassword, fedAuthActiveDirectoryMSI, fedAuthActiveDirectoryApplication) + } + + if p.disableEncryption { + f := "Encryption must not be disabled for federated authentication: encrypt='%s'" + return p, fmt.Errorf(f, encrypt) + } + } + + p.tlsKeyLogFile, ok = params["tls key log file"] + if ok && p.tlsKeyLogFile != "" && p.disableEncryption { + return p, errors.New("Cannot set tlsKeyLogFile when encryption is disabled") + } + return p, nil } diff --git a/doc/how-to-test-azure-ad-authentication.md b/doc/how-to-test-azure-ad-authentication.md new file mode 100644 index 00000000..56ccfb68 --- /dev/null +++ b/doc/how-to-test-azure-ad-authentication.md @@ -0,0 +1,176 @@ +# How to test Azure AD authentication + +To test Azure AD authentication requires an Azure SQL server configured with an +[Active Directory administrator](https://docs.microsoft.com/en-us/azure/sql-database/sql-database-aad-authentication-configure). +To test managed identity authentication, an Azure virtual machine configured with +[system-assigned and/or user-assigned identities](https://docs.microsoft.com/en-us/azure/active-directory/managed-identities-azure-resources/qs-configure-portal-windows-vm) +is also required. + +The necessary resources can be set up through any means including the +[Azure Portal](https://portal.azure.com/), the Azure CLI, the Azure PowerShell cmdlets or +[Terraform](https://terraform.io/). To support these instructions, use the Terraform script at +[examples/azuread/testing.tf](../examples/azuread/testing.tf). + +## Create Azure infrastructure + +Download [Terraform](https://terraform.io/) to a location on your PATH. + +Log in to Azure using the Azure CLI. + +```console +you@workstation:~$ az login +you@workstation:~$ az account show +``` + +If your Azure account has access to multiple subscriptions, use +`az account set --subscription ` to choose the correct one. You will need to have at +least Contributor access to the portal and permissions in Azure Active Directory to create users +and grants. + +Check out this source repository (if you haven't already), change to the `examples/azuread` +directory and run Terraform: + +```console +you@workstation:~$ git clone -b azure-auth https://github.com/wrosenuance/go-mssqldb.git +you@workstation:~$ cd go-mssqldb/examples/azuread +you@workstation:azuread$ terraform init +you@workstation:azuread$ terraform apply +``` + +This will create an Azure resource group, a SQL server with a database, a virtual machine with a +system-assigned identity and user-assigned identity. Resources are named based on a random +prefix: to specify the prefix, use `terraform apply -var prefix=`. + +Upon successful completion, Terraform will display some key details of the infrastructure that has + been created. This includes the SSH key to access the VM, the administrator account and password + for the Azure SQL server, and all the relevant resource names. + +Save the settings to a JSON file: + +```console +you@workstation:azuread$ terraform output -json > settings.json +``` + +Save the SSH private key to a file: + +```console +you@workstation:azuread$ terraform output vm_user_ssh_private_key > ssh-identity +``` + +Copy the `settings.json` to the new VM: + +```console +you@workstation:azuread$ scp -i ssh-identity settings.json "$(terraform output vm_admin_name)@$(terraform output vm_ip_address):" +``` + +## Set up Azure Virtual Machine for testing + +SSH to the new VM to continue setup: + +```console +you@workstation:azuread$ ssh -i ssh-identity "$(terraform output vm_admin_name)@$(terraform output vm_ip_address)" +``` + +Once on the VM, update the system and install some basic packages: + +```console +azureuser@azure-vm:~$ sudo apt update -y +azureuser@azure-vm:~$ sudo apt upgrade -y +azureuser@azure-vm:~$ sudo apt install -y git openssl jq build-essential +azureuser@azure-vm:~$ sudo snap install go --classic +``` + +Install the Azure CLI using the script as shown below, or follow the +[manual install instructions](https://docs.microsoft.com/en-us/cli/azure/install-azure-cli-apt): + +```console +azureuser@azure-vm:~$ curl -sL https://aka.ms/InstallAzureCLIDeb | sudo bash +``` + +## Generate service principal certificate file + +Log in to Azure on the VM and set the subscription: + +```console +azureuser@azure-vm:~$ az login +azureuser@azure-vm:~$ az account set --subscription "$(jq -r '.subscription_id.value' settings.json)" +``` + +Use OpenSSL to create a new certificate and key in PEM format, using the : + +```console +azureuser@azure-vm:~$ openssl rand -writerand ~/.rnd +azureuser@azure-vm:~$ openssl req -x509 -nodes -newkey rsa:4096 -keyout client.key -out client.crt \ + -subj "/C=US/ST=MA/L=Boston/O=Global Security/OU=IT Department/CN=AD-SP" +azureuser@azure-vm:~$ openssl pkcs12 -export -out client.p12 -inkey client.key -in client.crt \ + -passout "pass:$(jq -r '.app_sp_client_secret.value' settings.json)" +azureuser@azure-vm:~$ export APP_SP_CLIENT_CERT="$PWD/client.p12" +``` + +Use the Azure CLI to add the client certificate to the application service principal: + +```console +azureuser@azure-vm:~$ az ad sp credential reset --append --cert @client.crt \ + --name "$(jq -r '.app_sp_client_id.value' settings.json)" +``` + +## Build source code and authorize users in database + +Clone this repository, build and run the `examples/azuread` helper that verifies the database +exists and sets up access for the system-assigned and user-assigned identities. + +```console +azureuser@azure-vm:~$ git clone -b azure-auth https://github.com/wrosenuance/go-mssqldb.git +azureuser@azure-vm:~$ cd go-mssqldb +azureuser@azure-vm:go-mssqldb$ go generate ./... +azureuser@azure-vm:go-mssqldb$ go build -o azuread ./examples/azuread +azureuser@azure-vm:go-mssqldb$ eval "$(jq -r -f examples/azuread/environment-settings.jq ../settings.json)" +azureuser@azure-vm:go-mssqldb$ ./azuread -fedauth ActiveDirectoryPassword +``` + +For some basic connectivity tests, use the `examples/simple` helper. Run these commands on the +Azure VM so that identity authentication is possible. + +```console +azureuser@azure-vm:go-mssqldb$ go build -o simple ./examples/simple +azureuser@azure-vm:go-mssqldb$ jq -r -f examples/azuread/ad-user-password-dsn.jq ../settings.json | + xargs ./simple -debug -dsn +azureuser@azure-vm:go-mssqldb$ jq -r -f examples/azuread/ad-service-principal-dsn.jq ../settings.json | + xargs ./simple -debug -dsn +azureuser@azure-vm:go-mssqldb$ jq -r -f examples/azuread/ad-system-assigned-id-dsn.jq ../settings.json | + xargs ./simple -debug -dsn +azureuser@azure-vm:go-mssqldb$ jq -r -f examples/azuread/ad-user-assigned-id-dsn.jq ../settings.json | + xargs ./simple -debug -dsn +``` + +## Running the integration tests + +Now that your environment is configured, you can run `go test`: + +```console +azureuser@azure-vm:go-mssqldb$ export SQLSERVER_DSN="$(jq -r -f examples/azuread/ad-system-assigned-id-dsn.jq ../settings.json)" +azureuser@azure-vm:go-mssqldb$ go test -coverprofile=coverage.out ./... +``` + +## Tear down environment + +After you complete your testing, use Terraform to destroy the infrastructure you created. + +```console +you@workstation:azuread$ terraform destroy +``` + +## Troubleshooting + +After Terraform runs you should be able to see resources that were created in the +[Azure Portal](https://portal.azure.com/). + +If the Azure SQL server is successfully created you can connect to it using the AD admin user +and password in SSMS. SSMS will prompt you to create firewall rules if they are missing. You +can read the AD admin user and password from the `settings.json`, or run: + +```console +you@workstation:azuread$ terraform output sql_ad_admin_user +you@workstation:azuread$ terraform output sql_ad_admin_password +``` + diff --git a/examples/azuread/ad-service-principal-dsn.jq b/examples/azuread/ad-service-principal-dsn.jq new file mode 100644 index 00000000..b2426037 --- /dev/null +++ b/examples/azuread/ad-service-principal-dsn.jq @@ -0,0 +1 @@ +@uri "sqlserver://\(.app_sp_client_id.value)%40\(.tenant_id.value):\(.app_sp_client_secret.value)@\(.sql_server_fqdn.value)?database=\(.sql_database_name.value)&encrypt=true&fedauth=ActiveDirectoryApplication" \ No newline at end of file diff --git a/examples/azuread/ad-system-assigned-id-dsn.jq b/examples/azuread/ad-system-assigned-id-dsn.jq new file mode 100644 index 00000000..288f2b7c --- /dev/null +++ b/examples/azuread/ad-system-assigned-id-dsn.jq @@ -0,0 +1 @@ +@uri "sqlserver://\(.sql_server_fqdn.value)?database=\(.sql_database_name.value)&encrypt=true&fedauth=ActiveDirectoryMSI" \ No newline at end of file diff --git a/examples/azuread/ad-user-assigned-id-dsn.jq b/examples/azuread/ad-user-assigned-id-dsn.jq new file mode 100644 index 00000000..df31d09d --- /dev/null +++ b/examples/azuread/ad-user-assigned-id-dsn.jq @@ -0,0 +1 @@ +@uri "sqlserver://\(.user_assigned_identity_client_id.value)@\(.sql_server_fqdn.value)?database=\(.sql_database_name.value)&encrypt=true&fedauth=ActiveDirectoryMSI" \ No newline at end of file diff --git a/examples/azuread/ad-user-password-dsn.jq b/examples/azuread/ad-user-password-dsn.jq new file mode 100644 index 00000000..beebc5e1 --- /dev/null +++ b/examples/azuread/ad-user-password-dsn.jq @@ -0,0 +1 @@ +@uri "sqlserver://\(.sql_ad_admin_user.value):\(.sql_ad_admin_password.value)@\(.sql_server_fqdn.value)?database=\(.sql_database_name.value)&encrypt=true&fedauth=ActiveDirectoryPassword" \ No newline at end of file diff --git a/examples/azuread/azuread.go b/examples/azuread/azuread.go new file mode 100644 index 00000000..8bea6f35 --- /dev/null +++ b/examples/azuread/azuread.go @@ -0,0 +1,144 @@ +package main + +import ( + "database/sql" + "flag" + "fmt" + "log" + "net/url" + "os" + "strings" + "time" + + _ "github.com/denisenkom/go-mssqldb" +) + +var ( + debug = flag.Bool("debug", false, "enable debugging") + server = flag.String("server", os.Getenv("SQL_SERVER"), "the database server name") + port = flag.Int("port", 1433, "the database port") + database = flag.String("database", os.Getenv("SQL_DATABASE"), "the database name") + user = flag.String("user", os.Getenv("SQL_AD_ADMIN_USER"), "the AD administrator user name") + password = flag.String("password", os.Getenv("SQL_AD_ADMIN_PASSWORD"), "the AD administrator password") + fedauth = flag.String("fedauth", "ActiveDirectoryPassword", "the federated authentication scheme to use") + appName = flag.String("app-name", os.Getenv("APP_NAME"), "the application name to authorize") + vmName = flag.String("vm-name", os.Getenv("VM_NAME"), "the system identity name to authorize for this VM") + uaName = flag.String("ua-name", os.Getenv("UA_NAME"), "the user assigned identity name to authorize for this VM") +) + +func createConnStr(database string) string { + connString := fmt.Sprintf("sqlserver://%s:%s@%s:%d?encrypt=true", + url.QueryEscape(*user), url.QueryEscape(*password), + url.QueryEscape(*server), *port) + + if database != "" && database != "master" { + connString = connString + "&database=" + url.QueryEscape(database) + } + + if *fedauth != "" { + connString = connString + "&fedauth=" + url.QueryEscape(*fedauth) + } + + if *debug { + connString = connString + "&log=127" + } + + return connString +} + +func createDatabaseIfNotExists() error { + // Check database exists by connecting to master on the Azure SQL server + connString := createConnStr("master") + + log.Printf("Open: %s\n", connString) + + conn, err := sql.Open("sqlserver", connString) + if err != nil { + return err + } + + defer conn.Close() + + if err = conn.Ping(); err != nil { + return err + } + + quoted := strings.ReplaceAll(*database, "]", "]]") + sql := "IF NOT EXISTS (SELECT 1 FROM sys.databases WHERE name = @p1)\n CREATE DATABASE [" + quoted + "] ( SERVICE_OBJECTIVE = 'S0' )" + log.Printf("Exec: @p1 = '%s'\n%s\n", *database, sql) + _, err = conn.Exec(sql, *database) + + return err +} + +func addExternalUserIfNotExists(user string) error { + connString := createConnStr(*database) + + log.Printf("Open: %s\n", connString) + + var conn *sql.DB + var err error + + for retry := 0; retry < 8; retry++ { + conn, err = sql.Open("sqlserver", connString) + if err == nil { + if err = conn.Ping(); err == nil { + break + } + } + log.Printf("Connection failed: %v", err) + log.Println("Retry in 15 seconds") + time.Sleep(15 * time.Second) + } + if err != nil { + log.Printf("Connection failed: %v", err) + log.Println("No further retries will be attempted") + return err + } + + defer conn.Close() + + quoted := strings.ReplaceAll(user, "]", "]]") + sql := "IF NOT EXISTS (SELECT 1 FROM sys.database_principals WHERE name = @p1)\n CREATE USER [" + quoted + "] FROM EXTERNAL PROVIDER" + log.Printf("Exec: @p1 = '%s'\n%s\n", user, sql) + _, err = conn.Exec(sql, user) + if err != nil { + return err + } + + sql = "IF IS_ROLEMEMBER('db_owner', @p1) = 0\n ALTER ROLE [db_owner] ADD MEMBER [" + quoted + "]" + log.Printf("Exec: @p1 = '%s'\n%s\n", user, sql) + _, err = conn.Exec(sql, user) + + return err +} + +func main() { + flag.Parse() + + err := createDatabaseIfNotExists() + if err != nil { + log.Fatalf("Unable to create database [%s]: %v", *database, err) + } + + if *vmName != "" { + err = addExternalUserIfNotExists(*vmName) + if err != nil { + log.Fatalf("Unable to create user for system-assigned identity [%s]: %v", *vmName, err) + } + } + + if *appName != "" { + err = addExternalUserIfNotExists(*appName) + if err != nil { + log.Fatalf("Unable to create user for application identity [%s]: %v", *appName, err) + } + } + + if *uaName != "" { + err = addExternalUserIfNotExists(*uaName) + if err != nil { + log.Fatalf("Unable to create user for user-assigned identity [%s]: %v", *uaName, err) + } + } +} diff --git a/examples/azuread/environment-settings.jq b/examples/azuread/environment-settings.jq new file mode 100644 index 00000000..a8c9192a --- /dev/null +++ b/examples/azuread/environment-settings.jq @@ -0,0 +1,20 @@ +# Convert Terraform settings to shell environment exports. +[ + "set -a", + "SQL_SERVER=" + (.sql_server_fqdn.value | @sh), + "SQL_ADMIN_USER=" + (.sql_admin_user.value | @sh), + "SQL_ADMIN_PASSWORD=" + (.sql_admin_password.value | @sh), + "SQL_AD_ADMIN_USER=" + (.sql_ad_admin_user.value | @sh), + "SQL_AD_ADMIN_PASSWORD=" + (.sql_ad_admin_password.value | @sh), + "APP_SP_CLIENT_ID=" + (.app_sp_client_id.value | @sh), + "APP_SP_CLIENT_SECRET=" + (.app_sp_client_secret.value | @sh), + "SQL_DATABASE=" + (.sql_database_name.value | @sh), + "APP_NAME=" + (.app_name.value | @sh), + "VM_NAME=" + (.vm_name.value | @sh), + "VM_CLIENT_ID=" + (.vm_client_id.value | @sh), + "UA_NAME=" + (.user_assigned_identity_name.value | @sh), + "UA_CLIENT_ID=" + (.user_assigned_identity_client_id.value | @sh), + "AZURE_SUBSCRIPTION_ID=" + (.subscription_id.value | @sh), + "AZURE_TENANT_ID=" + (.tenant_id.value | @sh), + "set +a" +] | map([.]) | .[] | @tsv diff --git a/examples/azuread/testing.tf b/examples/azuread/testing.tf new file mode 100644 index 00000000..cef943d8 --- /dev/null +++ b/examples/azuread/testing.tf @@ -0,0 +1,490 @@ +# +# Terraform setup for Azure SQL with Azure Active Directory authentication +# + +# +# Set up Terraform provider versions +# +provider "azuread" { + version = "~> 0.7" +} + +provider "azurerm" { + version = "~> 1.36" +} + +provider "http" { + version = "~> 1.1" +} + +provider "random" { + version = "~> 2.2" +} + +provider "tls" { + version = "~> 2.1" +} + +# +# Variables +# +# These variables allow limited overrides to control the resource creation. +# To specify, run terraform apply -var name1=value1 [-var name2=value2]... +# E.g. terraform apply -var prefix=my-stuff +# will use "my-stuff" in place of the randomly generated ID that is used by default. +# +variable "prefix" { + description = "Prefix for Azure resource names" + type = string + default = "" +} + +variable "location" { + description = "Azure location for resources" + type = string + default = "East US" +} + +variable "vm_admin_name" { + description = "Name of administrative user on virtual machine" + type = string + default = "azureuser" +} + +variable "ssh_key" { + description = "Path to RSA SSH private key (unencrypted)" + type = string + default = "~/.ssh/id_rsa" +} + +variable "workstation_ip" { + description = "IP address of this workstation to add to SQL server firewall rules" + type = string + default = "" +} + +# +# If the prefix is not specified via the variable, a sixteen character alphanumeric suffix is +# generated and then the prefix is set to "go-mssql-test-" + +# +resource "random_string" "random_prefix" { + length = 16 + lower = true + number = true + upper = false + special = false +} + +# +# Set up a local variable to capture the prefix to use - either the user-specified from the +# variable, or else the generated name using the random string above. +# +# Some resource names (e.g. SQL server) are more restricted than others - e.g. hyphens are +# not permitted - so we create a restricted name prefix as well as a regular name prefix. +# +locals { + regular_name_prefix = var.prefix != "" ? var.prefix : "go-mssql-test-${random_string.random_prefix.result}" + restricted_name_prefix = var.prefix != "" ? lower(replace(var.prefix, "/[^A-Za-z0-9]/", "")) : "gomssqltest${random_string.random_prefix.result}" +} + +# +# SSH Key - generate if not available at the file named in the variable. +# Terraform will complain if var.ssh_key is empty as this is interpreted as referring to the +# current working directory, and that is not a file. Instead, if you want to avoid using an +# existing SSH key, make it a literal "no" or some other string that is not an existing file or +# directory. +# +data "tls_public_key" "file_ssh_key" { + count = fileexists(var.ssh_key) ? 1 : 0 + private_key_pem = fileexists(var.ssh_key) ? file(var.ssh_key) : "" +} + +resource "tls_private_key" "rand_ssh_key" { + algorithm = "ECDSA" +} + +locals { + private_key_pem = fileexists(var.ssh_key) ? data.tls_public_key.file_ssh_key.0.private_key_pem : tls_private_key.rand_ssh_key.private_key_pem + public_key_pem = fileexists(var.ssh_key) ? data.tls_public_key.file_ssh_key.0.public_key_pem : tls_private_key.rand_ssh_key.public_key_pem + public_key_openssh = fileexists(var.ssh_key) ? data.tls_public_key.file_ssh_key.0.public_key_openssh : tls_private_key.rand_ssh_key.public_key_openssh +} + +# +# Retrieve tenant, subscription and default domain information based on the current Azure login. +# +data "azurerm_client_config" "current" { +} + +data "azurerm_subscription" "current" { +} + +data "azuread_domains" "current" { + only_default = "true" +} + +# +# Use ipify.org to determine workstation IP if not provided. +# If this guesses incorrectly, specify your workstation IP with -var worstation_ip= +# when you run terraform apply. +# +data "http" "workstation_ip" { + url = "https://api.ipify.org/" +} + +locals { + workstation_ip = var.workstation_ip != "" ? var.workstation_ip : chomp(data.http.workstation_ip.body) +} + +# +# Set up the Azure resource group for all the test resources. +# +resource "azurerm_resource_group" "rg" { + name = "${local.regular_name_prefix}-rg" + location = var.location +} + +# +# Set up an AD User to use as AD Administrator for the Azure SQL server. +# +# Using a regular user account makes it simpler to log in as the user with SSMS or the Go +# driver when setting up the other permissions for the identities that will be tested. +# It appears to although you can make the AD Administrator a service principal, doing so +# leads to issues during logins that do not occur when the AD Administrator is a normal +# AD User account. +# +resource "random_password" "sql_ad_admin_sp_password" { + length = 32 + special = true +} + +resource "azuread_user" "sql_ad_admin" { + user_principal_name = "SQLAdmin.${local.restricted_name_prefix}@${data.azuread_domains.current.domains[0].domain_name}" + display_name = "SQL Admin for ${local.restricted_name_prefix}" + mail_nickname = "SQLAdmin.${local.restricted_name_prefix}" + password = random_password.sql_ad_admin_sp_password.result +} + +# +# Set up the Azure SQL Server +# +# A normal (non-AD) administrator username and password are also provisioned. However, it is +# not possible to create AD users without logging in via an AD-authenticated account, so this +# non-AD administrator is not able to create new AD user accounts. +# +resource "random_password" "sql_admin_password" { + length = 16 + special = true +} + +resource "azurerm_sql_server" "sql_server" { + name = local.restricted_name_prefix + resource_group_name = azurerm_resource_group.rg.name + location = azurerm_resource_group.rg.location + + version = "12.0" + administrator_login = "sql-admin" + administrator_login_password = random_password.sql_admin_password.result +} + +resource "azurerm_sql_active_directory_administrator" "sql_server" { + server_name = azurerm_sql_server.sql_server.name + resource_group_name = azurerm_sql_server.sql_server.resource_group_name + login = "sql-ad-admin" + tenant_id = data.azurerm_client_config.current.tenant_id + object_id = azuread_user.sql_ad_admin.id +} + +resource "azurerm_sql_firewall_rule" "sql_server_allow_azure" { + name = "AllowAzureAccess" + server_name = azurerm_sql_server.sql_server.name + resource_group_name = azurerm_sql_server.sql_server.resource_group_name + start_ip_address = "0.0.0.0" + end_ip_address = "0.0.0.0" +} + +resource "azurerm_sql_firewall_rule" "sql_server_allow_workstation" { + name = "AllowWorkstationAccess" + server_name = azurerm_sql_server.sql_server.name + resource_group_name = azurerm_sql_server.sql_server.resource_group_name + start_ip_address = local.workstation_ip + end_ip_address = local.workstation_ip +} + +# +# Set up the test database on the Azure SQL server +# +resource "azurerm_sql_database" "sql_db" { + name = "go-mssqldb" + + server_name = azurerm_sql_server.sql_server.name + resource_group_name = azurerm_sql_server.sql_server.resource_group_name + location = azurerm_sql_server.sql_server.location + + requested_service_objective_name = "S0" +} + +# +# Create a service principal that will be granted access to the database, +# representing an application login to the database. +# +resource "azuread_application" "app" { + name = "${local.regular_name_prefix}-app" +} + +resource "azuread_service_principal" "app_sp" { + application_id = azuread_application.app.application_id + app_role_assignment_required = false +} + +resource "random_password" "app_sp_password" { + length = 32 + special = true +} + +resource "azuread_service_principal_password" "app_sp" { + service_principal_id = azuread_service_principal.app_sp.id + value = random_password.app_sp_password.result + end_date_relative = "8760h" +} + + +# +# Create a user-assigned identity that we will add to the VM in addition to the +# system-assigned identity. +# +resource "azurerm_user_assigned_identity" "vm_user_id" { + resource_group_name = azurerm_resource_group.rg.name + location = azurerm_resource_group.rg.location + + name = "${local.restricted_name_prefix}-user-id" +} + +# +# Create an Azure VM for testing managed identity authentication. +# +# To support the Azure VM, we need a virtual network, a subnet, the public IP, the network +# security group, and the network interface. The network security group allows incoming SSH +# from the anywhere on the internet. +# +resource "azurerm_virtual_network" "vm_vnet" { + name = "${local.regular_name_prefix}-vnet" + resource_group_name = azurerm_resource_group.rg.name + location = azurerm_resource_group.rg.location + address_space = ["10.0.0.0/16"] +} + +resource "azurerm_subnet" "vm_subnet" { + name = "${local.regular_name_prefix}-vm-sn" + resource_group_name = azurerm_resource_group.rg.name + virtual_network_name = azurerm_virtual_network.vm_vnet.name + address_prefix = "10.0.2.0/24" +} + +resource "azurerm_public_ip" "vm_ip" { + name = "${local.regular_name_prefix}-vm-ip" + resource_group_name = azurerm_resource_group.rg.name + location = azurerm_resource_group.rg.location + allocation_method = "Dynamic" + idle_timeout_in_minutes = 30 +} + +resource "azurerm_network_security_group" "vm_nsg" { + name = "${local.regular_name_prefix}-vm-nsg" + resource_group_name = azurerm_resource_group.rg.name + location = azurerm_resource_group.rg.location + + security_rule { + name = "SSH" + priority = 1001 + direction = "Inbound" + access = "Allow" + protocol = "Tcp" + source_port_range = "*" + destination_port_range = "22" + source_address_prefix = "*" + destination_address_prefix = "*" + } +} + +resource "azurerm_network_interface" "vm_nic" { + name = "${local.regular_name_prefix}-vm-nic" + resource_group_name = azurerm_resource_group.rg.name + location = azurerm_resource_group.rg.location + network_security_group_id = azurerm_network_security_group.vm_nsg.id + + ip_configuration { + name = "${local.regular_name_prefix}-vm-nic-config" + subnet_id = azurerm_subnet.vm_subnet.id + private_ip_address_allocation = "Dynamic" + public_ip_address_id = azurerm_public_ip.vm_ip.id + } +} + +# +# Given the networking setup, now create the Azure VM +# +resource "azurerm_virtual_machine" "vm" { + name = "${local.regular_name_prefix}-vm" + resource_group_name = azurerm_resource_group.rg.name + location = azurerm_resource_group.rg.location + network_interface_ids = [azurerm_network_interface.vm_nic.id] + vm_size = "Standard_B1s" + + storage_os_disk { + name = "${local.regular_name_prefix}-vm-os" + caching = "ReadWrite" + create_option = "FromImage" + managed_disk_type = "Standard_LRS" + } + + storage_image_reference { + publisher = "Canonical" + offer = "UbuntuServer" + sku = "18.04-LTS" + version = "latest" + } + + os_profile { + computer_name = "${local.regular_name_prefix}-vm" + admin_username = var.vm_admin_name + } + + os_profile_linux_config { + disable_password_authentication = true + ssh_keys { + path = "/home/${var.vm_admin_name}/.ssh/authorized_keys" + key_data = local.public_key_openssh + } + } + + # Configure the VM with both SystemAssigned and a UserAssigned identity + identity { + type = "SystemAssigned, UserAssigned" + identity_ids = [azurerm_user_assigned_identity.vm_user_id.id] + } +} + +# Retrieve the application ID corresponding to the service principal ID assigned to the VM. +data "azuread_service_principal" "vm_sp" { + object_id = azurerm_virtual_machine.vm.identity.0.principal_id +} + +# Wait for public IP to be assigned after VM is created so we can report it in the outputs. +data "azurerm_public_ip" "vm_ip" { + name = azurerm_public_ip.vm_ip.name + resource_group_name = azurerm_virtual_machine.vm.resource_group_name +} + +# +# After provisioning or refreshing, Terraform will populate these outputs. +# These capture the necessary pieces of information to access the new infrastructure. +# +output "tenant_id" { + description = "Azure tenant ID" + value = data.azurerm_client_config.current.tenant_id +} + +output "subscription_id" { + description = "Azure subscription ID" + value = data.azurerm_client_config.current.subscription_id +} + +output "sql_server_name" { + description = "Azure SQL server name" + value = azurerm_sql_server.sql_server.name +} + +output "sql_server_fqdn" { + description = "Azure SQL server domain name" + value = azurerm_sql_server.sql_server.fully_qualified_domain_name +} + +output "sql_ad_admin_user" { + description = "Azure SQL administrator name (AD authentication)" + value = azuread_user.sql_ad_admin.user_principal_name +} + +output "sql_ad_admin_password" { + description = "Azure SQL administrator password (AD authentication)" + value = random_password.sql_ad_admin_sp_password.result + sensitive = true +} + +output "sql_admin_user" { + description = "Azure SQL administrator name (SQL server authentication)" + value = azurerm_sql_server.sql_server.administrator_login +} + +output "sql_admin_password" { + description = "Azure SQL administrator password (SQL server authentication)" + value = random_password.sql_admin_password.result + sensitive = true +} + +output "sql_database_name" { + description = "Azure SQL database name" + value = azurerm_sql_database.sql_db.name +} + +output "vm_name" { + description = "Azure virtual machine name" + value = azurerm_virtual_machine.vm.name +} + +output "vm_client_id" { + description = "Azure VM system-assigned identity client ID" + value = data.azuread_service_principal.vm_sp.application_id +} + +output "vm_principal_id" { + description = "Azure VM system-assigned identity principal ID" + value = azurerm_virtual_machine.vm.identity.0.principal_id +} + +output "vm_ip_address" { + description = "Azure virtual machine public IP" + value = data.azurerm_public_ip.vm_ip.ip_address +} + +output "vm_admin_name" { + description = "Azure virtual machine admin user name" + value = var.vm_admin_name +} + +output "vm_user_ssh_private_key" { + description = "Azure virtual machine admin user private SSH key" + value = local.private_key_pem + sensitive = true +} + +output "vm_user_ssh_openssh_key" { + description = "Azure virtual machine admin user SSH public key" + value = local.public_key_openssh + sensitive = true +} + +output "app_sp_client_id" { + description = "Service principal client ID for application user" + value = azuread_application.app.application_id +} + +output "app_name" { + description = "Service principal name for application user" + value = azuread_application.app.name +} + +output "app_sp_client_secret" { + description = "Service principal client secret for application user" + value = random_password.app_sp_password.result + sensitive = true +} + +output "user_assigned_identity_name" { + description = "User-assigned identity for the Azure virtual machine" + value = azurerm_user_assigned_identity.vm_user_id.name +} + +output "user_assigned_identity_client_id" { + description = "User-assigned identity client ID" + value = azurerm_user_assigned_identity.vm_user_id.client_id +} diff --git a/examples/simple/simple.go b/examples/simple/simple.go index 67f88aa4..2c8965c2 100644 --- a/examples/simple/simple.go +++ b/examples/simple/simple.go @@ -5,12 +5,16 @@ import ( "flag" "fmt" "log" + "net/url" + "os" _ "github.com/denisenkom/go-mssqldb" ) var ( + database = flag.String("database", "", "the database name") debug = flag.Bool("debug", false, "enable debugging") + dsn = flag.String("dsn", os.Getenv("SQLSERVER_DSN"), "complete SQL DSN") password = flag.String("password", "", "the database password") port *int = flag.Int("port", 1433, "the database port") server = flag.String("server", "", "the database server") @@ -20,24 +24,35 @@ var ( func main() { flag.Parse() - if *debug { - fmt.Printf(" password:%s\n", *password) - fmt.Printf(" port:%d\n", *port) - fmt.Printf(" server:%s\n", *server) - fmt.Printf(" user:%s\n", *user) + var connString string + + if *dsn == "" { + if *debug { + fmt.Printf(" server: %s\n", *server) + fmt.Printf(" port: %d\n", *port) + fmt.Printf(" user: %s\n", *user) + fmt.Printf(" password: %s\n", *password) + fmt.Printf(" database: %s\n", *database) + } + + connString = fmt.Sprintf("sqlserver://%s:%s@%s:%d?database=%s&encrypt=true", + url.QueryEscape(*user), url.QueryEscape(*password), + url.QueryEscape(*server), *port, url.QueryEscape(*database)) + } else { + connString = *dsn } - connString := fmt.Sprintf("server=%s;user id=%s;password=%s;port=%d", *server, *user, *password, *port) if *debug { - fmt.Printf(" connString:%s\n", connString) + fmt.Printf(" dsn: %s\n", connString) } + conn, err := sql.Open("mssql", connString) if err != nil { log.Fatal("Open connection failed:", err.Error()) } defer conn.Close() - stmt, err := conn.Prepare("select 1, 'abc'") + stmt, err := conn.Prepare("select 1, 'abc', suser_name()") if err != nil { log.Fatal("Prepare failed:", err.Error()) } @@ -46,12 +61,14 @@ func main() { row := stmt.QueryRow() var somenumber int64 var somechars string - err = row.Scan(&somenumber, &somechars) + var someuser string + err = row.Scan(&somenumber, &somechars, &someuser) if err != nil { log.Fatal("Scan failed:", err.Error()) } - fmt.Printf("somenumber:%d\n", somenumber) - fmt.Printf("somechars:%s\n", somechars) + fmt.Printf("number: %d\n", somenumber) + fmt.Printf("chars: %s\n", somechars) + fmt.Printf("user: %s\n", someuser) fmt.Printf("bye\n") } diff --git a/fedauth.go b/fedauth.go new file mode 100644 index 00000000..8dbfcf44 --- /dev/null +++ b/fedauth.go @@ -0,0 +1,101 @@ +package mssql + +import ( + "context" + "crypto/rsa" + "crypto/x509" + "errors" + "fmt" + "io/ioutil" + + "github.com/Azure/go-autorest/autorest/adal" + "golang.org/x/crypto/pkcs12" +) + +const ( + activeDirectoryEndpoint = "https://login.microsoftonline.com/" + azureSQLResource = "https://database.windows.net/" + driverClientID = "7f98cb04-cd1e-40df-9140-3bf7e2cea4db" +) + +func fedAuthGetClientCertificate(clientCertPath, clientCertPassword string) (*x509.Certificate, *rsa.PrivateKey, error) { + pkcs, err := ioutil.ReadFile(clientCertPath) + if err != nil { + return nil, nil, fmt.Errorf("Failed to read the AD client certificate from path %s: %v", clientCertPath, err) + } + + privateKey, certificate, err := pkcs12.Decode(pkcs, clientCertPassword) + if err != nil { + return nil, nil, fmt.Errorf("Failed to read the AD client certificate from path %s: %v", clientCertPath, err) + } + + rsaPrivateKey, isRsaKey := privateKey.(*rsa.PrivateKey) + if !isRsaKey { + return nil, nil, fmt.Errorf("AD client certificate at path %s must contain an RSA private key", clientCertPath) + } + + return certificate, rsaPrivateKey, nil +} + +func fedAuthGetAccessToken(ctx context.Context, resource, tenantID string, p connectParams, log optionalLogger) (accessToken string, err error) { + // The activeDirectoryEndpoint URL is used as a base against which the + // tenant ID is resolved. When the workflow provides a complete endpoint + // URL for the tenant, the URL resolution just returns that endpoint. + oauthConfig, err := adal.NewOAuthConfig(activeDirectoryEndpoint, tenantID) + if err != nil { + log.Printf("Failed to obtain OAuth configuration for endpoint %s and tenant %s: %v", activeDirectoryEndpoint, tenantID, err) + return "", err + } + + var token *adal.ServicePrincipalToken + if p.fedAuthLibrary == fedAuthLibrarySecurityToken { + // When the security token library is used, the token is obtained without input + // from the server, so the AD endpoint and Azure SQL resource URI are provided + // from the constants above. + if p.aadClientCertPath != "" { + var certificate *x509.Certificate + var rsaPrivateKey *rsa.PrivateKey + certificate, rsaPrivateKey, err = fedAuthGetClientCertificate(p.aadClientCertPath, p.password) + if err == nil { + token, err = adal.NewServicePrincipalTokenFromCertificate(*oauthConfig, p.user, certificate, rsaPrivateKey, azureSQLResource) + } + } else { + token, err = adal.NewServicePrincipalToken(*oauthConfig, p.user, p.password, azureSQLResource) + } + } else if p.fedAuthLibrary == fedAuthLibraryADAL { + // When the ADAL workflow is used, the server provides the endpoint (STS URL) + // and resource (server SPN) during the login process. The STS URL is passed + // as the tenant ID and has already been used to build the OAuth config. + if p.fedAuthADALWorkflow == fedAuthADALWorkflowPassword { + token, err = adal.NewServicePrincipalTokenFromUsernamePassword(*oauthConfig, driverClientID, p.user, p.password, resource) + + } else if p.fedAuthADALWorkflow == fedAuthADALWorkflowMSI { + // When using MSI, to request a specific client ID or user-assigned identity, + // provide the ID as the username. + var msiEndpoint string + msiEndpoint, err = adal.GetMSIEndpoint() + if err == nil { + if p.user == "" { + token, err = adal.NewServicePrincipalTokenFromMSI(msiEndpoint, resource) + } else { + token, err = adal.NewServicePrincipalTokenFromMSIWithUserAssignedID(msiEndpoint, resource, p.user) + } + } + } + } else { + return "", errors.New("Unsupported federated authentication library") + } + + if err != nil { + log.Printf("Failed to obtain service principal token for client id %s in tenant %s: %v", p.user, tenantID, err) + return "", err + } + + err = token.RefreshWithContext(ctx) + if err != nil { + log.Printf("Failed to refresh service principal token for client id %s in tenant %s: %v", p.user, tenantID, err) + return "", err + } + + return token.Token().AccessToken, nil +} diff --git a/fedauth_test.go b/fedauth_test.go new file mode 100644 index 00000000..7b6e1de1 --- /dev/null +++ b/fedauth_test.go @@ -0,0 +1,189 @@ +package mssql + +import ( + "context" + "database/sql" + "net/url" + "os" + "strings" + "testing" +) + +func checkAzureSQLEnvironment(fedAuth string, t *testing.T) (*url.URL, string) { + u := &url.URL{ + Scheme: "sqlserver", + Host: os.Getenv("SQL_SERVER"), + } + + if u.Host == "" { + t.Skip("Azure SQL Server name not provided in SQL_SERVER environment variable") + } + + database := os.Getenv("SQL_DATABASE") + if database == "" { + t.Skip("Azure SQL database name not provided in SQL_DATABASE environment variable") + } + + tenantID := os.Getenv("AZURE_TENANT_ID") + if tenantID == "" { + t.Skip("Azure tenant ID not provided in AZURE_TENANT_ID environment variable") + } + + query := u.Query() + + query.Add("database", database) + query.Add("encrypt", "true") + query.Add("fedauth", fedAuth) + + u.RawQuery = query.Encode() + + return u, tenantID +} + +func checkFedAuthUserPassword(t *testing.T) *url.URL { + u, _ := checkAzureSQLEnvironment(fedAuthActiveDirectoryPassword, t) + + username := os.Getenv("SQL_AD_ADMIN_USER") + password := os.Getenv("SQL_AD_ADMIN_PASSWORD") + + if username == "" || password == "" { + t.Skip("Username and password login requires SQL_AD_ADMIN_USER and SQL_AD_ADMIN_PASSWORD environment variables") + } + + u.User = url.UserPassword(username, password) + + return u +} + +func checkFedAuthAppPassword(t *testing.T) *url.URL { + u, tenantID := checkAzureSQLEnvironment(fedAuthActiveDirectoryApplication, t) + + appClientID := os.Getenv("APP_SP_CLIENT_ID") + appPassword := os.Getenv("APP_SP_CLIENT_SECRET") + + if appClientID == "" || appPassword == "" { + t.Skip("Application (service principal) login requires APP_SP_CLIENT_ID and APP_SP_CLIENT_SECRET environment variables") + } + + u.User = url.UserPassword(appClientID+"@"+tenantID, appPassword) + + return u +} + +func checkFedAuthAppCertPath(t *testing.T) *url.URL { + u := checkFedAuthAppPassword(t) + + appCertPath := os.Getenv("APP_SP_CLIENT_CERT") + if appCertPath == "" { + t.Skip("Application (service principal) certificate login requires APP_SP_CLIENT_CERT with path to certificate") + } + + query := u.Query() + query.Add("clientcertpath", appCertPath) + u.RawQuery = query.Encode() + + return u +} + +func checkFedAuthVMSystemID(t *testing.T) (*url.URL, string) { + u, tenantID := checkAzureSQLEnvironment(fedAuthActiveDirectoryMSI, t) + + vmClientID := os.Getenv("VM_CLIENT_ID") + if vmClientID == "" { + t.Skip("System-assigned identity login test requires VM_CLIENT_ID environment variable") + } + + return u, vmClientID + "@" + tenantID +} + +func checkFedAuthVMUserAssignedID(t *testing.T) (*url.URL, string) { + u, tenantID := checkAzureSQLEnvironment(fedAuthActiveDirectoryMSI, t) + + uaClientID := os.Getenv("UA_CLIENT_ID") + if uaClientID == "" { + t.Skip("User-assigned identity login test requires UA_CLIENT_ID environment variable") + } + + u.User = url.User(uaClientID) + + return u, uaClientID + "@" + tenantID +} + +func checkLoggedInUser(expected string, u *url.URL, t *testing.T) { + db, err := sql.Open("sqlserver", u.String()) + if err != nil { + t.Fatalf("Failed to open URL %v: %v", u, err) + } + + defer db.Close() + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + sql := "SELECT SUSER_NAME()" + + stmt, err := db.PrepareContext(ctx, sql) + if err != nil { + t.Fatalf("Failed to prepare query %s: %v", sql, err) + } + + defer stmt.Close() + + rows, err := stmt.QueryContext(ctx) + if err != nil { + t.Fatalf("Failed to fetch query result for %s: %v", sql, err) + } + + defer rows.Close() + + var username string + if !rows.Next() { + t.Fatalf("Empty result set for query %s", sql) + } + + err = rows.Scan(&username) + if err != nil { + t.Fatalf("Failed to fetch first row for %s: %v", sql, err) + } + + if !strings.EqualFold(username, expected) { + t.Fatalf("Expected username %s: actual: %s", expected, username) + } + + t.Logf("Logged in username %s matches expected %s", username, expected) +} + +func TestFedAuthWithUserAndPassword(t *testing.T) { + SetLogger(testLogger{t}) + u := checkFedAuthUserPassword(t) + + checkLoggedInUser(u.User.Username(), u, t) +} + +func TestFedAuthWithApplicationUsingPassword(t *testing.T) { + SetLogger(testLogger{t}) + u := checkFedAuthAppPassword(t) + + checkLoggedInUser(u.User.Username(), u, t) +} + +func TestFedAuthWithApplicationUsingCertificate(t *testing.T) { + SetLogger(testLogger{t}) + u := checkFedAuthAppCertPath(t) + + checkLoggedInUser(u.User.Username(), u, t) +} + +func TestFedAuthWithSystemAssignedIdentity(t *testing.T) { + u, vmName := checkFedAuthVMSystemID(t) + SetLogger(testLogger{t}) + + checkLoggedInUser(vmName, u, t) +} + +func TestFedAuthWithUserAssignedIdentity(t *testing.T) { + SetLogger(testLogger{t}) + u, uaName := checkFedAuthVMUserAssignedID(t) + + checkLoggedInUser(uaName, u, t) +} diff --git a/go.mod b/go.mod index ebc02ab8..2526b04a 100644 --- a/go.mod +++ b/go.mod @@ -3,6 +3,8 @@ module github.com/denisenkom/go-mssqldb go 1.11 require ( + github.com/Azure/go-autorest/autorest/adal v0.8.0 github.com/golang-sql/civil v0.0.0-20190719163853-cb61b32ac6fe golang.org/x/crypto v0.0.0-20190325154230-a5d413f7728c + golang.org/x/tools v0.0.0-20191206204035-259af5ff87bd // indirect ) diff --git a/go.sum b/go.sum index 1887801b..581ab1c2 100644 --- a/go.sum +++ b/go.sum @@ -1,5 +1,30 @@ +github.com/Azure/go-autorest v13.3.0+incompatible h1:8Ix0VdeOllBx9jEcZ2Wb1uqWUpE1awmJiaHztwaJCPk= +github.com/Azure/go-autorest/autorest v0.9.0 h1:MRvx8gncNaXJqOoLmhNjUAKh33JJF8LyxPhomEtOsjs= +github.com/Azure/go-autorest/autorest v0.9.0/go.mod h1:xyHB1BMZT0cuDHU7I0+g046+BFDTQ8rEZB0s4Yfa6bI= +github.com/Azure/go-autorest/autorest v0.9.2 h1:6AWuh3uWrsZJcNoCHrCF/+g4aKPCU39kaMO6/qrnK/4= +github.com/Azure/go-autorest/autorest/adal v0.5.0/go.mod h1:8Z9fGy2MpX0PvDjB1pEgQTmVqjGhiHBW7RJJEciWzS0= +github.com/Azure/go-autorest/autorest/adal v0.8.0 h1:CxTzQrySOxDnKpLjFJeZAS5Qrv/qFPkgLjx5bOAi//I= +github.com/Azure/go-autorest/autorest/adal v0.8.0/go.mod h1:Z6vX6WXXuyieHAXwMj0S6HY6e6wcHn37qQMBQlvY3lc= +github.com/Azure/go-autorest/autorest/date v0.1.0/go.mod h1:plvfp3oPSKwf2DNjlBjWF/7vwR+cUD/ELuzDCXwHUVA= +github.com/Azure/go-autorest/autorest/date v0.2.0 h1:yW+Zlqf26583pE43KhfnhFcdmSWlm5Ew6bxipnr/tbM= +github.com/Azure/go-autorest/autorest/date v0.2.0/go.mod h1:vcORJHLJEh643/Ioh9+vPmf1Ij9AEBM5FuBIXLmIy0g= +github.com/Azure/go-autorest/autorest/mocks v0.1.0/go.mod h1:OTyCOPRA2IgIlWxVYxBee2F5Gr4kF2zd2J5cFRaIDN0= +github.com/Azure/go-autorest/autorest/mocks v0.2.0/go.mod h1:OTyCOPRA2IgIlWxVYxBee2F5Gr4kF2zd2J5cFRaIDN0= +github.com/Azure/go-autorest/autorest/mocks v0.3.0/go.mod h1:a8FDP3DYzQ4RYfVAxAN3SVSiiO77gL2j2ronKKP0syM= +github.com/Azure/go-autorest/logger v0.1.0/go.mod h1:oExouG+K6PryycPJfVSxi/koC6LSNgds39diKLz7Vrc= +github.com/Azure/go-autorest/tracing v0.5.0 h1:TRn4WjSnkcSy5AEG3pnbtFSwNtwzjr4VYyQflFE619k= +github.com/Azure/go-autorest/tracing v0.5.0/go.mod h1:r/s2XiOKccPW3HrqB+W0TQzfbtp2fGCgRFtBroKn4Dk= +github.com/dgrijalva/jwt-go v3.2.0+incompatible h1:7qlOGliEKZXTDg6OTjfoBKDXWrumCAMpl/TFQ4/5kLM= +github.com/dgrijalva/jwt-go v3.2.0+incompatible/go.mod h1:E3ru+11k8xSBh+hMPgOLZmtrrCbhqsmaPHjLKYnJCaQ= github.com/golang-sql/civil v0.0.0-20190719163853-cb61b32ac6fe h1:lXe2qZdvpiX5WZkZR4hgp4KJVfY3nMkvmwbVkpv1rVY= github.com/golang-sql/civil v0.0.0-20190719163853-cb61b32ac6fe/go.mod h1:8vg3r2VgvsThLBIFL93Qb5yWzgyZWhEmBwUJWevAkK0= +golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= golang.org/x/crypto v0.0.0-20190325154230-a5d413f7728c h1:Vj5n4GlwjmQteupaxJ9+0FNOmBrHfq7vN4btdGoDZgI= golang.org/x/crypto v0.0.0-20190325154230-a5d413f7728c/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= +golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= +golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= +golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= +golang.org/x/tools v0.0.0-20191206204035-259af5ff87bd h1:Zc7EU2PqpsNeIfOoVA7hvQX4cS3YDJEs5KlfatT3hLo= +golang.org/x/tools v0.0.0-20191206204035-259af5ff87bd/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= +golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= diff --git a/log_conn.go b/log_conn.go new file mode 100644 index 00000000..4777e4c1 --- /dev/null +++ b/log_conn.go @@ -0,0 +1,80 @@ +package mssql + +import ( + "encoding/hex" + "net" + "strings" + "time" +) + +type connLogger struct { + conn net.Conn + readKind, writeKind string + readCount, writeCount int + logger Logger +} + +var _ net.Conn = &connLogger{} + +func newConnLogger(conn net.Conn, kind string, logger Logger) net.Conn { + if len(kind) > 0 && !strings.HasPrefix(kind, " ") { + kind = " " + kind + } + + cl := &connLogger{ + conn: conn, + readKind: "R" + kind, + writeKind: "W" + kind, + logger: logger, + } + + return cl +} + +func (cl *connLogger) Read(p []byte) (n int, err error) { + n, err = cl.conn.Read(p) + + if n > 0 { + dump := hex.Dump(p) + cl.logger.Printf("%s %d\n%s", cl.readKind, cl.readCount, dump) + cl.readCount += n + } + + return +} + +func (cl *connLogger) Write(p []byte) (n int, err error) { + n, err = cl.conn.Write(p) + + if n > 0 { + dump := hex.Dump(p) + cl.logger.Printf("%s %d\n%s", cl.writeKind, cl.writeCount, dump) + cl.writeCount += n + } + + return +} + +func (cl *connLogger) Close() (err error) { + return cl.conn.Close() +} + +func (cl *connLogger) LocalAddr() net.Addr { + return cl.conn.LocalAddr() +} + +func (cl *connLogger) RemoteAddr() net.Addr { + return cl.conn.RemoteAddr() +} + +func (cl *connLogger) SetDeadline(t time.Time) error { + return cl.conn.SetDeadline(t) +} + +func (cl *connLogger) SetReadDeadline(t time.Time) error { + return cl.conn.SetReadDeadline(t) +} + +func (cl *connLogger) SetWriteDeadline(t time.Time) error { + return cl.conn.SetWriteDeadline(t) +} diff --git a/tds.go b/tds.go index 94198364..98c1186c 100644 --- a/tds.go +++ b/tds.go @@ -10,6 +10,7 @@ import ( "io" "io/ioutil" "net" + "os" "sort" "strconv" "strings" @@ -89,24 +90,27 @@ const ( // 4.19.2 Out-of-Band Attention Signal: https://msdn.microsoft.com/en-us/library/dd305167.aspx packAttention = 6 - packBulkLoadBCP = 7 - packTransMgrReq = 14 - packNormal = 15 - packLogin7 = 16 - packSSPIMessage = 17 - packPrelogin = 18 + packBulkLoadBCP = 7 + packFedAuthToken = 8 + packTransMgrReq = 14 + packNormal = 15 + packLogin7 = 16 + packSSPIMessage = 17 + packPrelogin = 18 ) // prelogin fields // http://msdn.microsoft.com/en-us/library/dd357559.aspx const ( - preloginVERSION = 0 - preloginENCRYPTION = 1 - preloginINSTOPT = 2 - preloginTHREADID = 3 - preloginMARS = 4 - preloginTRACEID = 5 - preloginTERMINATOR = 0xff + preloginVERSION = 0 + preloginENCRYPTION = 1 + preloginINSTOPT = 2 + preloginTHREADID = 3 + preloginMARS = 4 + preloginTRACEID = 5 + preloginFEDAUTHREQUIRED = 6 + preloginNONCEOPT = 7 + preloginTERMINATOR = 0xff ) const ( @@ -116,6 +120,33 @@ const ( encryptReq = 3 // Encryption is required. ) +const ( + featExtSESSIONRECOVERY byte = 0x01 + featExtFEDAUTH byte = 0x02 + featExtCOLUMNENCRYPTION byte = 0x04 + featExtGLOBALTRANSACTIONS byte = 0x05 + featExtAZURESQLSUPPORT byte = 0x08 + featExtDATACLASSIFICATION byte = 0x09 + featExtUTF8SUPPORT byte = 0x0A + featExtTERMINATOR byte = 0xFF +) + +// Federated authentication library affects the login data structure and message sequence. +const ( + fedAuthLibraryLiveIDCompactToken = 0x00 + fedAuthLibrarySecurityToken = 0x01 + fedAuthLibraryADAL = 0x02 + + fedAuthLibraryReserved = 0x7F +) + +// Federated authentication ADAL workflow affects the mechanism used to authenticate. +const ( + fedAuthADALWorkflowPassword = 0x01 + fedAuthADALWorkflowIntegrated = 0x02 + fedAuthADALWorkflowMSI = 0x03 +) + type tdsSession struct { buf *tdsBuffer loginAck loginAckStruct @@ -137,6 +168,7 @@ const ( logParams = 16 logTransaction = 32 logDebug = 64 + logTraffic = 128 ) type columnStruct struct { @@ -238,6 +270,16 @@ const ( fIntSecurity = 0x80 ) +// OptionFlags3 +// http://msdn.microsoft.com/en-us/library/dd304019.aspx +const ( + fChangePassword = 1 + fSendYukonBinaryXML = 2 + fUserInstance = 4 + fUnknownCollationHandling = 8 + fExtension = 0x10 +) + // TypeFlags const ( // 4 bits for fSQLType @@ -246,29 +288,34 @@ const ( ) type login struct { - TDSVersion uint32 - PacketSize uint32 - ClientProgVer uint32 - ClientPID uint32 - ConnectionID uint32 - OptionFlags1 uint8 - OptionFlags2 uint8 - TypeFlags uint8 - OptionFlags3 uint8 - ClientTimeZone int32 - ClientLCID uint32 - HostName string - UserName string - Password string - AppName string - ServerName string - CtlIntName string - Language string - Database string - ClientID [6]byte - SSPI []byte - AtchDBFile string - ChangePassword string + TDSVersion uint32 + PacketSize uint32 + ClientProgVer uint32 + ClientPID uint32 + ConnectionID uint32 + OptionFlags1 uint8 + OptionFlags2 uint8 + TypeFlags uint8 + OptionFlags3 uint8 + ClientTimeZone int32 + ClientLCID uint32 + HostName string + UserName string + Password string + AppName string + ServerName string + CtlIntName string + Language string + Database string + ClientID [6]byte + SSPI []byte + AtchDBFile string + ChangePassword string + FedAuthLibrary byte + FedAuthEcho byte + FedAuthToken string + FedAuthNonce []byte + FedAuthADALWorkflow byte } type loginHeader struct { @@ -295,7 +342,7 @@ type loginHeader struct { ServerNameOffset uint16 ServerNameLength uint16 ExtensionOffset uint16 - ExtensionLenght uint16 + ExtensionLength uint16 CtlIntNameOffset uint16 CtlIntNameLength uint16 LanguageOffset uint16 @@ -357,42 +404,81 @@ func sendLogin(w *tdsBuffer, login login) error { database := str2ucs2(login.Database) atchdbfile := str2ucs2(login.AtchDBFile) changepassword := str2ucs2(login.ChangePassword) + fedauthtoken := str2ucs2(login.FedAuthToken) + + // Determine if any feature extensions need to be written so we know whether + // to include the option flag and offset to the data. + var featureExtLength uint32 + fedAuth := login.FedAuthLibrary == fedAuthLibrarySecurityToken || login.FedAuthLibrary == fedAuthLibraryADAL + if login.FedAuthLibrary == fedAuthLibrarySecurityToken { + // Each feature extension record is 1 byte to indicate the type of the + // feature extension, four bytes for the data size, then the data size. + // For the SecurityToken data, the size is one byte to indicate that + // it's the SecurityToken library, four bytes for the token length, + // then the token bytes. + featureExtLength += uint32(1 + 4 + 1 + 4 + len(fedauthtoken)) + } else if login.FedAuthLibrary == fedAuthLibraryADAL { + // In addition to the 1 + 4 bytes for the feature extension header, + // ADAL just requires one byte to indicate the library and one to + // set the workflow (password, integrated or MSI). + featureExtLength += uint32(1 + 4 + 1 + 1) + } + + // If any feature extension records are written, a final single-byte terminator + // record must also be written, and the fExtension flag must be set. + if featureExtLength > 0 { + featureExtLength++ + login.OptionFlags3 |= fExtension + } else { + login.OptionFlags3 &^= fExtension + } + hdr := loginHeader{ - TDSVersion: login.TDSVersion, - PacketSize: login.PacketSize, - ClientProgVer: login.ClientProgVer, - ClientPID: login.ClientPID, - ConnectionID: login.ConnectionID, - OptionFlags1: login.OptionFlags1, - OptionFlags2: login.OptionFlags2, - TypeFlags: login.TypeFlags, - OptionFlags3: login.OptionFlags3, - ClientTimeZone: login.ClientTimeZone, - ClientLCID: login.ClientLCID, - HostNameLength: uint16(utf8.RuneCountInString(login.HostName)), - UserNameLength: uint16(utf8.RuneCountInString(login.UserName)), - PasswordLength: uint16(utf8.RuneCountInString(login.Password)), - AppNameLength: uint16(utf8.RuneCountInString(login.AppName)), - ServerNameLength: uint16(utf8.RuneCountInString(login.ServerName)), - CtlIntNameLength: uint16(utf8.RuneCountInString(login.CtlIntName)), - LanguageLength: uint16(utf8.RuneCountInString(login.Language)), - DatabaseLength: uint16(utf8.RuneCountInString(login.Database)), - ClientID: login.ClientID, - SSPILength: uint16(len(login.SSPI)), - AtchDBFileLength: uint16(utf8.RuneCountInString(login.AtchDBFile)), - ChangePasswordLength: uint16(utf8.RuneCountInString(login.ChangePassword)), + TDSVersion: login.TDSVersion, + PacketSize: login.PacketSize, + ClientProgVer: login.ClientProgVer, + ClientPID: login.ClientPID, + ConnectionID: login.ConnectionID, + OptionFlags1: login.OptionFlags1, + OptionFlags2: login.OptionFlags2, + TypeFlags: login.TypeFlags, + OptionFlags3: login.OptionFlags3, + ClientTimeZone: login.ClientTimeZone, + ClientLCID: login.ClientLCID, + HostNameLength: uint16(utf8.RuneCountInString(login.HostName)), + AppNameLength: uint16(utf8.RuneCountInString(login.AppName)), + ServerNameLength: uint16(utf8.RuneCountInString(login.ServerName)), + CtlIntNameLength: uint16(utf8.RuneCountInString(login.CtlIntName)), + LanguageLength: uint16(utf8.RuneCountInString(login.Language)), + DatabaseLength: uint16(utf8.RuneCountInString(login.Database)), + ClientID: login.ClientID, + SSPILength: uint16(len(login.SSPI)), + AtchDBFileLength: uint16(utf8.RuneCountInString(login.AtchDBFile)), } offset := uint16(binary.Size(hdr)) hdr.HostNameOffset = offset offset += uint16(len(hostname)) - hdr.UserNameOffset = offset - offset += uint16(len(username)) - hdr.PasswordOffset = offset - offset += uint16(len(password)) + + if !fedAuth { + hdr.UserNameOffset = offset + hdr.UserNameLength = uint16(utf8.RuneCountInString(login.UserName)) + offset += uint16(len(username)) + hdr.PasswordOffset = offset + hdr.PasswordLength = uint16(utf8.RuneCountInString(login.Password)) + offset += uint16(len(password)) + } + hdr.AppNameOffset = offset offset += uint16(len(appname)) hdr.ServerNameOffset = offset offset += uint16(len(servername)) + + if featureExtLength > 0 { + hdr.ExtensionOffset = offset + hdr.ExtensionLength = 4 + offset += hdr.ExtensionLength + } + hdr.CtlIntNameOffset = offset offset += uint16(len(ctlintname)) hdr.LanguageOffset = offset @@ -403,9 +489,14 @@ func sendLogin(w *tdsBuffer, login login) error { offset += uint16(len(login.SSPI)) hdr.AtchDBFileOffset = offset offset += uint16(len(atchdbfile)) - hdr.ChangePasswordOffset = offset - offset += uint16(len(changepassword)) - hdr.Length = uint32(offset) + + if !fedAuth { + hdr.ChangePasswordOffset = offset + hdr.ChangePasswordLength = uint16(utf8.RuneCountInString(login.ChangePassword)) + offset += uint16(len(changepassword)) + } + + hdr.Length = uint32(offset) + featureExtLength var err error err = binary.Write(w, binary.LittleEndian, &hdr) if err != nil { @@ -415,13 +506,16 @@ func sendLogin(w *tdsBuffer, login login) error { if err != nil { return err } - _, err = w.Write(username) - if err != nil { - return err - } - _, err = w.Write(password) - if err != nil { - return err + + if !fedAuth { + _, err = w.Write(username) + if err != nil { + return err + } + _, err = w.Write(password) + if err != nil { + return err + } } _, err = w.Write(appname) if err != nil { @@ -431,6 +525,12 @@ func sendLogin(w *tdsBuffer, login login) error { if err != nil { return err } + if featureExtLength > 0 { + err = binary.Write(w, binary.LittleEndian, uint32(offset)) + if err != nil { + return err + } + } _, err = w.Write(ctlintname) if err != nil { return err @@ -451,10 +551,61 @@ func sendLogin(w *tdsBuffer, login login) error { if err != nil { return err } - _, err = w.Write(changepassword) + + if !fedAuth { + _, err = w.Write(changepassword) + if err != nil { + return err + } + } + + // Write the feature extension record for federated authentication, if in use + if login.FedAuthLibrary == fedAuthLibrarySecurityToken { + w.WriteByte(featExtFEDAUTH) + binary.Write(w, binary.LittleEndian, uint32(1+4+len(fedauthtoken))) + w.WriteByte(login.FedAuthLibrary<<1 | login.FedAuthEcho) + binary.Write(w, binary.LittleEndian, uint32(len(fedauthtoken))) + w.Write(fedauthtoken) + } else if login.FedAuthLibrary == fedAuthLibraryADAL { + w.WriteByte(featExtFEDAUTH) + binary.Write(w, binary.LittleEndian, uint32(1+1)) + w.WriteByte(login.FedAuthLibrary<<1 | login.FedAuthEcho) + w.WriteByte(login.FedAuthADALWorkflow) + } + // Write the feature extension terminator if any feature extensions are written + if featureExtLength > 0 { + w.WriteByte(featExtTERMINATOR) + } + return w.FinishPacket() +} + +// https://docs.microsoft.com/en-us/openspecs/windows_protocols/ms-tds/827d9632-2957-4d54-b9ea-384530ae79d0 +func sendFedAuthInfo(w *tdsBuffer, login login) (err error) { + fedauthtoken := str2ucs2(login.FedAuthToken) + tokenlen := len(fedauthtoken) + datalen := 4 + tokenlen + len(login.FedAuthNonce) + + w.BeginPacket(packFedAuthToken, false) + err = binary.Write(w, binary.LittleEndian, uint32(datalen)) if err != nil { - return err + return + } + + err = binary.Write(w, binary.LittleEndian, uint32(tokenlen)) + if err != nil { + return } + + _, err = w.Write(fedauthtoken) + if err != nil { + return + } + + _, err = w.Write(login.FedAuthNonce) + if err != nil { + return + } + return w.FinishPacket() } @@ -751,6 +902,10 @@ initiate_connection: return nil, err } + if p.logFlags&logTraffic != 0 { + conn = newConnLogger(conn, "TCP", log) + } + toconn := newTimeoutConn(conn, p.conn_timeout) outbuf := newTdsBuffer(p.packetSize, toconn) @@ -778,6 +933,10 @@ initiate_connection: preloginMARS: {0}, // MARS disabled } + if p.fedAuthLibrary != fedAuthLibraryReserved { + fields[preloginFEDAUTHREQUIRED] = []byte{1} + } + err = writePrelogin(outbuf, fields) if err != nil { return nil, err @@ -788,6 +947,21 @@ initiate_connection: return nil, err } + // If the server returns the preloginFEDAUTHREQUIRED field, then federated authentication + // is supported. The actual value may be 0 or 1, where 0 means either SSPI or federated + // authentication is allowed, while 1 means only federated authentication is allowed. + var fedAuthEcho byte + if fedAuthSupport, ok := fields[preloginFEDAUTHREQUIRED]; ok { + if len(fedAuthSupport) == 1 { + // We need to be able to echo the value back to the server + fedAuthEcho = fedAuthSupport[0] + } else { + return nil, fmt.Errorf("Federated authentication flag length should be 1: is %d", len(fedAuthSupport)) + } + } else if p.fedAuthLibrary != fedAuthLibraryReserved { + return nil, fmt.Errorf("Federated authentication is not supported by the server") + } + encryptBytes, ok := fields[preloginENCRYPTION] if !ok { return nil, fmt.Errorf("Encrypt negotiation failed") @@ -811,6 +985,14 @@ initiate_connection: if p.trustServerCertificate { config.InsecureSkipVerify = true } + if p.tlsKeyLogFile != "" { + if w, err := os.OpenFile(p.tlsKeyLogFile, os.O_WRONLY|os.O_CREATE|os.O_APPEND, 0600); err == nil { + defer w.Close() + config.KeyLogWriter = w + } else { + return nil, fmt.Errorf("Cannot open TLS key log file %s: %v", p.tlsKeyLogFile, err) + } + } config.ServerName = p.hostInCertificate // fix for https://github.com/denisenkom/go-mssqldb/issues/166 // Go implementation of TLS payload size heuristic algorithm splits single TDS package to multiple TCP segments, @@ -823,7 +1005,11 @@ initiate_connection: tlsConn := tls.Client(&passthrough, &config) err = tlsConn.Handshake() passthrough.c = toconn - outbuf.transport = tlsConn + if sess.logFlags&logTraffic != 0 { + outbuf.transport = newConnLogger(tlsConn, "TLS", log) + } else { + outbuf.transport = tlsConn + } if err != nil { return nil, fmt.Errorf("TLS Handshake failed: %v", err) } @@ -835,23 +1021,54 @@ initiate_connection: } login := login{ - TDSVersion: verTDS74, - PacketSize: uint32(outbuf.PackageSize()), - Database: p.database, - OptionFlags2: fODBC, // to get unlimited TEXTSIZE - HostName: p.workstation, - ServerName: p.host, - AppName: p.appname, - TypeFlags: p.typeFlags, + TDSVersion: verTDS74, + PacketSize: uint32(outbuf.PackageSize()), + Database: p.database, + OptionFlags2: fODBC, // to get unlimited TEXTSIZE + HostName: p.workstation, + ServerName: p.host, + AppName: p.appname, + TypeFlags: p.typeFlags, + FedAuthLibrary: p.fedAuthLibrary, + FedAuthEcho: fedAuthEcho, + FedAuthADALWorkflow: p.fedAuthADALWorkflow, } auth, auth_ok := getAuth(p.user, p.password, p.serverSPN, p.workstation) - if auth_ok { + if login.FedAuthLibrary == fedAuthLibraryReserved && auth_ok { + if p.logFlags&logDebug != 0 { + log.Println("Starting SSPI login") + } login.SSPI, err = auth.InitialBytes() if err != nil { return nil, err } login.OptionFlags2 |= fIntSecurity defer auth.Free() + } else if login.FedAuthLibrary == fedAuthLibrarySecurityToken { + login.FedAuthToken, err = fedAuthGetAccessToken(ctx, "", p.aadTenantID, p, log) + if err != nil { + if p.logFlags&logDebug != 0 { + log.Printf("Failed to retrieve service principal token for federated authentication security token library: %v", err) + } + return nil, err + } + if p.logFlags&logDebug != 0 { + log.Println("Successfully obtained service principal token for federated authentication security token library") + } + } else if login.FedAuthLibrary == fedAuthLibraryADAL { + if login.FedAuthADALWorkflow == fedAuthADALWorkflowPassword { + if p.logFlags&logDebug != 0 { + log.Printf("Starting ADAL username/password workflow for user %s", p.user) + } + login.UserName = p.user + login.Password = p.password + } else if login.FedAuthADALWorkflow == fedAuthADALWorkflowMSI { + if p.logFlags&logDebug != 0 { + log.Println("Starting ADAL managed service identity (MSI) workflow") + } + } else { + return nil, fmt.Errorf("Unsupported ADAL workflow type 0x%02x", int(login.FedAuthADALWorkflow)) + } } else { login.UserName = p.user login.Password = p.password @@ -885,6 +1102,20 @@ initiate_connection: } sspi_msg = nil } + case fedAuthInfoStruct: + // For ADAL workflows this contains the STS URL and server SPN. + // If received outside of an ADAL workflow, ignore. + if login.FedAuthLibrary == fedAuthLibraryADAL { + login.FedAuthToken, err = fedAuthGetAccessToken(ctx, token.ServerSPN, token.STSURL, p, log) + if err != nil { + return nil, err + } + // Now need to send the token as a FEDINFO packet + err = sendFedAuthInfo(outbuf, login) + if err != nil { + return nil, err + } + } case loginAckStruct: success = true sess.loginAck = token diff --git a/token.go b/token.go index 1acac8a5..9e2a5c5e 100644 --- a/token.go +++ b/token.go @@ -6,31 +6,34 @@ import ( "errors" "fmt" "io" + "io/ioutil" "net" "strconv" "strings" ) -//go:generate stringer -type token +//go:generate go run golang.org/x/tools/cmd/stringer -type token type token byte // token ids const ( - tokenReturnStatus token = 121 // 0x79 - tokenColMetadata token = 129 // 0x81 - tokenOrder token = 169 // 0xA9 - tokenError token = 170 // 0xAA - tokenInfo token = 171 // 0xAB - tokenReturnValue token = 0xAC - tokenLoginAck token = 173 // 0xad - tokenRow token = 209 // 0xd1 - tokenNbcRow token = 210 // 0xd2 - tokenEnvChange token = 227 // 0xE3 - tokenSSPI token = 237 // 0xED - tokenDone token = 253 // 0xFD - tokenDoneProc token = 254 - tokenDoneInProc token = 255 + tokenReturnStatus token = 121 // 0x79 + tokenColMetadata token = 129 // 0x81 + tokenOrder token = 169 // 0xA9 + tokenError token = 170 // 0xAA + tokenInfo token = 171 // 0xAB + tokenReturnValue token = 0xAC + tokenLoginAck token = 173 // 0xad + tokenFeatureExtAck token = 174 // 0xAE + tokenRow token = 209 // 0xd1 + tokenNbcRow token = 210 // 0xd2 + tokenEnvChange token = 227 // 0xE3 + tokenSSPI token = 237 // 0xED + tokenFedAuthInfo token = 238 // 0xEE + tokenDone token = 253 // 0xFD + tokenDoneProc token = 254 + tokenDoneInProc token = 255 ) // done flags @@ -69,6 +72,11 @@ const ( envRouting = 20 ) +const ( + fedAuthInfoSTSURL = 0x01 + fedAuthInfoSPN = 0x02 +) + // COLMETADATA flags // https://msdn.microsoft.com/en-us/library/dd357363.aspx const ( @@ -424,6 +432,73 @@ func parseSSPIMsg(r *tdsBuffer) sspiMsg { return sspiMsg(buf) } +type fedAuthInfoStruct struct { + STSURL string + ServerSPN string +} + +type fedAuthInfoOpt struct { + fedAuthInfoID byte + dataLength, dataOffset uint32 +} + +func parseFedAuthInfo(r *tdsBuffer) fedAuthInfoStruct { + size := r.uint32() + + var STSURL, SPN string + var err error + + // Each fedAuthInfoOpt is one byte to indicate the info ID, + // then a four byte offset and a four byte length. + count := r.uint32() + offset := uint32(4) + opts := make([]fedAuthInfoOpt, count) + + for i := uint32(0); i < count; i++ { + fedAuthInfoID := r.byte() + dataLength := r.uint32() + dataOffset := r.uint32() + offset += 1 + 4 + 4 + + opts[i] = fedAuthInfoOpt{ + fedAuthInfoID: fedAuthInfoID, + dataLength: dataLength, + dataOffset: dataOffset, + } + } + + data := make([]byte, size-offset) + r.ReadFull(data) + + for i := uint32(0); i < count; i++ { + if opts[i].dataOffset < offset { + badStreamPanicf("Fed auth info opt stated data offset %d is before data begins in packet at %d", + opts[i].dataOffset, offset) + } else if opts[i].dataOffset+opts[i].dataLength > size { + badStreamPanicf("Fed auth info opt stated data length %d added to stated offset exceeds size of packet %d", + opts[i].dataOffset+opts[i].dataLength, size) + } else { + optData := data[opts[i].dataOffset-offset : opts[i].dataOffset-offset+opts[i].dataLength] + if opts[i].fedAuthInfoID == fedAuthInfoSTSURL { + STSURL, err = ucs22str(optData) + } else if opts[i].fedAuthInfoID == fedAuthInfoSPN { + SPN, err = ucs22str(optData) + } else { + err = fmt.Errorf("Unexpected fed auth info opt ID %d", int(opts[i].fedAuthInfoID)) + } + + if err != nil { + badStreamPanic(err) + } + } + } + + return fedAuthInfoStruct{ + STSURL: STSURL, + ServerSPN: SPN, + } +} + type loginAckStruct struct { Interface uint8 TDSVersion uint32 @@ -447,6 +522,45 @@ func parseLoginAck(r *tdsBuffer) loginAckStruct { return res } +type fedAuthAckStruct struct { + Nonce []byte + Signature []byte +} + +func parseFeatureExtAck(r *tdsBuffer) map[byte]interface{} { + ack := map[byte]interface{}{} + + for feature := r.byte(); feature != featExtTERMINATOR; feature = r.byte() { + length := r.uint32() + + switch feature { + case featExtFEDAUTH: + // In theory we need to know the federated authentication library to + // know how to parse, but the alternatives provide compatible structures. + fedAuthAck := fedAuthAckStruct{} + if length >= 32 { + fedAuthAck.Nonce = make([]byte, 0, 32) + r.ReadFull(fedAuthAck.Nonce) + length -= 32 + } + if length >= 32 { + fedAuthAck.Signature = make([]byte, 0, 32) + r.ReadFull(fedAuthAck.Signature) + length -= 32 + } + ack[feature] = fedAuthAck + + } + + // Skip unprocessed bytes + if length > 0 { + io.CopyN(ioutil.Discard, r, int64(length)) + } + } + + return ack +} + // http://msdn.microsoft.com/en-us/library/dd357363.aspx func parseColMetadata72(r *tdsBuffer) (columns []columnStruct) { count := r.uint16() @@ -571,12 +685,18 @@ func processSingleResponse(sess *tdsSession, ch chan tokenStruct, outs map[strin case tokenSSPI: ch <- parseSSPIMsg(sess.buf) return + case tokenFedAuthInfo: + ch <- parseFedAuthInfo(sess.buf) + return case tokenReturnStatus: returnStatus := parseReturnStatus(sess.buf) ch <- returnStatus case tokenLoginAck: loginAck := parseLoginAck(sess.buf) ch <- loginAck + case tokenFeatureExtAck: + featureExtAck := parseFeatureExtAck(sess.buf) + ch <- featureExtAck case tokenOrder: order := parseOrder(sess.buf) ch <- order diff --git a/token_string.go b/token_string.go index c075b23b..74389fdd 100644 --- a/token_string.go +++ b/token_string.go @@ -1,29 +1,46 @@ -// Code generated by "stringer -type token"; DO NOT EDIT +// Code generated by "stringer -type token"; DO NOT EDIT. package mssql -import "fmt" +import "strconv" + +func _() { + // An "invalid array index" compiler error signifies that the constant values have changed. + // Re-run the stringer command to generate them again. + var x [1]struct{} + _ = x[tokenReturnStatus-121] + _ = x[tokenColMetadata-129] + _ = x[tokenOrder-169] + _ = x[tokenError-170] + _ = x[tokenInfo-171] + _ = x[tokenReturnValue-172] + _ = x[tokenLoginAck-173] + _ = x[tokenFeatureExtAck-174] + _ = x[tokenRow-209] + _ = x[tokenNbcRow-210] + _ = x[tokenEnvChange-227] + _ = x[tokenSSPI-237] + _ = x[tokenFedAuthInfo-238] + _ = x[tokenDone-253] + _ = x[tokenDoneProc-254] + _ = x[tokenDoneInProc-255] +} const ( _token_name_0 = "tokenReturnStatus" _token_name_1 = "tokenColMetadata" - _token_name_2 = "tokenOrdertokenErrortokenInfo" - _token_name_3 = "tokenLoginAck" - _token_name_4 = "tokenRowtokenNbcRow" - _token_name_5 = "tokenEnvChange" - _token_name_6 = "tokenSSPI" - _token_name_7 = "tokenDonetokenDoneProctokenDoneInProc" + _token_name_2 = "tokenOrdertokenErrortokenInfotokenReturnValuetokenLoginAcktokenFeatureExtAck" + _token_name_3 = "tokenRowtokenNbcRow" + _token_name_4 = "tokenEnvChange" + _token_name_5 = "tokenSSPItokenFedAuthInfo" + _token_name_6 = "tokenDonetokenDoneProctokenDoneInProc" ) var ( - _token_index_0 = [...]uint8{0, 17} - _token_index_1 = [...]uint8{0, 16} - _token_index_2 = [...]uint8{0, 10, 20, 29} - _token_index_3 = [...]uint8{0, 13} - _token_index_4 = [...]uint8{0, 8, 19} - _token_index_5 = [...]uint8{0, 14} - _token_index_6 = [...]uint8{0, 9} - _token_index_7 = [...]uint8{0, 9, 22, 37} + _token_index_2 = [...]uint8{0, 10, 20, 29, 45, 58, 76} + _token_index_3 = [...]uint8{0, 8, 19} + _token_index_5 = [...]uint8{0, 9, 25} + _token_index_6 = [...]uint8{0, 9, 22, 37} ) func (i token) String() string { @@ -32,22 +49,21 @@ func (i token) String() string { return _token_name_0 case i == 129: return _token_name_1 - case 169 <= i && i <= 171: + case 169 <= i && i <= 174: i -= 169 return _token_name_2[_token_index_2[i]:_token_index_2[i+1]] - case i == 173: - return _token_name_3 case 209 <= i && i <= 210: i -= 209 - return _token_name_4[_token_index_4[i]:_token_index_4[i+1]] + return _token_name_3[_token_index_3[i]:_token_index_3[i+1]] case i == 227: - return _token_name_5 - case i == 237: - return _token_name_6 + return _token_name_4 + case 237 <= i && i <= 238: + i -= 237 + return _token_name_5[_token_index_5[i]:_token_index_5[i+1]] case 253 <= i && i <= 255: i -= 253 - return _token_name_7[_token_index_7[i]:_token_index_7[i+1]] + return _token_name_6[_token_index_6[i]:_token_index_6[i+1]] default: - return fmt.Sprintf("token(%d)", i) + return "token(" + strconv.FormatInt(int64(i), 10) + ")" } } diff --git a/tvp_go19_db_test.go b/tvp_go19_db_test.go index 6cf42641..f4c2ddff 100644 --- a/tvp_go19_db_test.go +++ b/tvp_go19_db_test.go @@ -209,7 +209,8 @@ func TestTVP(t *testing.T) { bFalse := false floatValue64 := 0.123 floatValue32 := float32(-10.123) - timeNow := time.Now().UTC() + // Best datetime2 precision is 100 ns granularity + timeNow := time.Now().UTC().Truncate(100 * time.Nanosecond) param1 := []TvptableRow{ { PBinary: []byte("ccc"), @@ -462,7 +463,8 @@ func TestTVP_WithTag(t *testing.T) { bFalse := false floatValue64 := 0.123 floatValue32 := float32(-10.123) - timeNow := time.Now().UTC() + // Default (and maximum) datetime2 precision is 7 digits or 100ns + timeNow := time.Now().UTC().Truncate(100 * time.Nanosecond) param1 := []TvptableRowWithSkipTag{ { PBinary: []byte("ccc"), From 0ce9674c95f26d3e962377cbae4058384b4cf1d1 Mon Sep 17 00:00:00 2001 From: "Rose, William" Date: Wed, 22 Jan 2020 12:55:11 +1100 Subject: [PATCH 2/2] Adjust for Go 1.8-1.10 --- appveyor.yml | 2 ++ examples/azuread/azuread.go | 4 ++-- examples/tvp/tvp.go | 2 ++ 3 files changed, 6 insertions(+), 2 deletions(-) diff --git a/appveyor.yml b/appveyor.yml index c4d2bb06..51de3c26 100644 --- a/appveyor.yml +++ b/appveyor.yml @@ -30,6 +30,8 @@ install: - go version - go env - go get -u github.com/golang-sql/civil + - go get -u golang.org/x/crypto/pkcs12 + - go get -u github.com/Azure/go-autorest/autorest/adal build_script: - go build diff --git a/examples/azuread/azuread.go b/examples/azuread/azuread.go index 8bea6f35..cbcf9342 100644 --- a/examples/azuread/azuread.go +++ b/examples/azuread/azuread.go @@ -63,7 +63,7 @@ func createDatabaseIfNotExists() error { return err } - quoted := strings.ReplaceAll(*database, "]", "]]") + quoted := strings.Replace(*database, "]", "]]", -1) sql := "IF NOT EXISTS (SELECT 1 FROM sys.databases WHERE name = @p1)\n CREATE DATABASE [" + quoted + "] ( SERVICE_OBJECTIVE = 'S0' )" log.Printf("Exec: @p1 = '%s'\n%s\n", *database, sql) _, err = conn.Exec(sql, *database) @@ -98,7 +98,7 @@ func addExternalUserIfNotExists(user string) error { defer conn.Close() - quoted := strings.ReplaceAll(user, "]", "]]") + quoted := strings.Replace(user, "]", "]]", -1) sql := "IF NOT EXISTS (SELECT 1 FROM sys.database_principals WHERE name = @p1)\n CREATE USER [" + quoted + "] FROM EXTERNAL PROVIDER" log.Printf("Exec: @p1 = '%s'\n%s\n", user, sql) _, err = conn.Exec(sql, user) diff --git a/examples/tvp/tvp.go b/examples/tvp/tvp.go index a07bb652..eae614ef 100644 --- a/examples/tvp/tvp.go +++ b/examples/tvp/tvp.go @@ -1,3 +1,5 @@ +// +build go1.9 + package main import (