diff --git a/README.md b/README.md index d953a1f..9c75dd9 100644 --- a/README.md +++ b/README.md @@ -92,6 +92,12 @@ Please refer to the [Coordinator JWT Authentication](https://trino.io/docs/current/security/jwt.html) for server-side configuration. +#### Authorization header forwarding +This driver supports forwarding authorization headers by adding a [NamedArg](https://godoc.org/database/sql#NamedArg) with the name `accessToken` (e.g., `accessToken=`) and setting the `ForwardAuthorizationHeader` field in the [Config](https://godoc.org/github.com/trinodb/trino-go-client/trino#Config) struct to `true`. + +When enabled, this configuration will override the `AccessToken` set in the `Config` struct. + + #### System access control and per-query user information It's possible to pass user information to Trino, different from the principal diff --git a/trino/trino.go b/trino/trino.go index 6b817f1..1453f03 100644 --- a/trino/trino.go +++ b/trino/trino.go @@ -132,16 +132,17 @@ const ( authorizationHeader = "Authorization" - kerberosEnabledConfig = "KerberosEnabled" - kerberosKeytabPathConfig = "KerberosKeytabPath" - kerberosPrincipalConfig = "KerberosPrincipal" - kerberosRealmConfig = "KerberosRealm" - kerberosConfigPathConfig = "KerberosConfigPath" - kerberosRemoteServiceNameConfig = "KerberosRemoteServiceName" - sslCertPathConfig = "SSLCertPath" - sslCertConfig = "SSLCert" - accessTokenConfig = "accessToken" - explicitPrepareConfig = "explicitPrepare" + kerberosEnabledConfig = "KerberosEnabled" + kerberosKeytabPathConfig = "KerberosKeytabPath" + kerberosPrincipalConfig = "KerberosPrincipal" + kerberosRealmConfig = "KerberosRealm" + kerberosConfigPathConfig = "KerberosConfigPath" + kerberosRemoteServiceNameConfig = "KerberosRemoteServiceName" + sslCertPathConfig = "SSLCertPath" + sslCertConfig = "SSLCert" + accessTokenConfig = "accessToken" + explicitPrepareConfig = "explicitPrepare" + forwardAuthorizationHeaderConfig = "forwardAuthorizationHeader" mapKeySeparator = ":" mapEntrySeparator = ";" @@ -168,22 +169,23 @@ var _ driver.Driver = &Driver{} // Config is a configuration that can be encoded to a DSN string. type Config struct { - ServerURI string // URI of the Trino server, e.g. http://user@localhost:8080 - Source string // Source of the connection (optional) - Catalog string // Catalog (optional) - Schema string // Schema (optional) - SessionProperties map[string]string // Session properties (optional) - ExtraCredentials map[string]string // Extra credentials (optional) - CustomClientName string // Custom client name (optional) - KerberosEnabled string // KerberosEnabled (optional, default is false) - KerberosKeytabPath string // Kerberos Keytab Path (optional) - KerberosPrincipal string // Kerberos Principal used to authenticate to KDC (optional) - KerberosRemoteServiceName string // Trino coordinator Kerberos service name (optional) - KerberosRealm string // The Kerberos Realm (optional) - KerberosConfigPath string // The krb5 config path (optional) - SSLCertPath string // The SSL cert path for TLS verification (optional) - SSLCert string // The SSL cert for TLS verification (optional) - AccessToken string // An access token (JWT) for authentication (optional) + ServerURI string // URI of the Trino server, e.g. http://user@localhost:8080 + Source string // Source of the connection (optional) + Catalog string // Catalog (optional) + Schema string // Schema (optional) + SessionProperties map[string]string // Session properties (optional) + ExtraCredentials map[string]string // Extra credentials (optional) + CustomClientName string // Custom client name (optional) + KerberosEnabled string // KerberosEnabled (optional, default is false) + KerberosKeytabPath string // Kerberos Keytab Path (optional) + KerberosPrincipal string // Kerberos Principal used to authenticate to KDC (optional) + KerberosRemoteServiceName string // Trino coordinator Kerberos service name (optional) + KerberosRealm string // The Kerberos Realm (optional) + KerberosConfigPath string // The krb5 config path (optional) + SSLCertPath string // The SSL cert path for TLS verification (optional) + SSLCert string // The SSL cert for TLS verification (optional) + AccessToken string // An access token (JWT) for authentication (optional) + ForwardAuthorizationHeader bool // Allow forwarding the `accessToken` named query parameter in the authorization header, overwriting the `AccessToken` option, if set (optional) } // FormatDSN returns a DSN string from the configuration. @@ -211,6 +213,10 @@ func (c *Config) FormatDSN() (string, error) { query := make(url.Values) query.Add("source", source) + if c.ForwardAuthorizationHeader { + query.Add(forwardAuthorizationHeaderConfig, "true") + } + KerberosEnabled, _ := strconv.ParseBool(c.KerberosEnabled) isSSL := serverURL.Scheme == "https" @@ -277,16 +283,17 @@ func (c *Config) FormatDSN() (string, error) { // Conn is a Trino connection. type Conn struct { - baseURL string - auth *url.Userinfo - httpClient http.Client - httpHeaders http.Header - kerberosClient *client.Client - kerberosEnabled bool - kerberosRemoteServiceName string - progressUpdater ProgressUpdater - progressUpdaterPeriod queryProgressCallbackPeriod - useExplicitPrepare bool + baseURL string + auth *url.Userinfo + httpClient http.Client + httpHeaders http.Header + kerberosEnabled bool + kerberosClient *client.Client + kerberosRemoteServiceName string + progressUpdater ProgressUpdater + progressUpdaterPeriod queryProgressCallbackPeriod + useExplicitPrepare bool + forwardAuthorizationHeader bool } var ( @@ -303,6 +310,9 @@ func newConn(dsn string) (*Conn, error) { query := serverURL.Query() kerberosEnabled, _ := strconv.ParseBool(query.Get(kerberosEnabledConfig)) + + forwardAuthorizationHeader, _ := strconv.ParseBool(query.Get(forwardAuthorizationHeaderConfig)) + useExplicitPrepare := true if query.Get(explicitPrepareConfig) != "" { useExplicitPrepare, _ = strconv.ParseBool(query.Get(explicitPrepareConfig)) @@ -359,13 +369,14 @@ func newConn(dsn string) (*Conn, error) { } c := &Conn{ - baseURL: serverURL.Scheme + "://" + serverURL.Host, - httpClient: *httpClient, - httpHeaders: make(http.Header), - kerberosClient: kerberosClient, - kerberosEnabled: kerberosEnabled, - kerberosRemoteServiceName: query.Get(kerberosRemoteServiceNameConfig), - useExplicitPrepare: useExplicitPrepare, + baseURL: serverURL.Scheme + "://" + serverURL.Host, + httpClient: *httpClient, + httpHeaders: make(http.Header), + kerberosClient: kerberosClient, + kerberosEnabled: kerberosEnabled, + kerberosRemoteServiceName: query.Get(kerberosRemoteServiceNameConfig), + useExplicitPrepare: useExplicitPrepare, + forwardAuthorizationHeader: forwardAuthorizationHeader, } var user string @@ -909,6 +920,12 @@ func (st *driverStmt) exec(ctx context.Context, args []driver.NamedValue) (*stmt continue } + if st.conn.forwardAuthorizationHeader && arg.Name == accessTokenConfig { + token := arg.Value.(string) + hs.Add(authorizationHeader, getAuthorization(token)) + continue + } + s, err := Serial(arg.Value) if err != nil { return nil, err diff --git a/trino/trino_test.go b/trino/trino_test.go index 42f08f5..cc4c181 100644 --- a/trino/trino_test.go +++ b/trino/trino_test.go @@ -1911,3 +1911,35 @@ func TestExec(t *testing.T) { _, err = db.Exec("DROP TABLE memory.default.test") require.NoError(t, err, "Failed executing DROP TABLE query") } + +func TestForwardAuthorizationHeaderConfig(t *testing.T) { + c := &Config{ + ServerURI: "https://foobar@localhost:8090", + ForwardAuthorizationHeader: true, + } + + dsn, err := c.FormatDSN() + require.NoError(t, err) + + want := "https://foobar@localhost:8090?forwardAuthorizationHeader=true&source=trino-go-client" + + assert.Equal(t, want, dsn) +} + +func TestForwardAuthorizationHeader(t *testing.T) { + var captureAuthHeader string + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // Capture the Authorization header for later inspection + captureAuthHeader = r.Header.Get("Authorization") + })) + + t.Cleanup(ts.Close) + + db, err := sql.Open("trino", ts.URL+"?forwardAuthorizationHeader=true") + require.NoError(t, err) + + _, _ = db.Query("SELECT 1", sql.Named("accessToken", string("token"))) // Ingore response to focus on header capture + require.Equal(t, "Bearer token", captureAuthHeader, "Authorization header is incorrect") + + assert.NoError(t, db.Close()) +}