Skip to content

Commit

Permalink
feat: support custom request headers
Browse files Browse the repository at this point in the history
  • Loading branch information
Patrick Demers committed Oct 11, 2023
1 parent 0fcb544 commit 8fdd3e9
Show file tree
Hide file tree
Showing 3 changed files with 44 additions and 2 deletions.
11 changes: 11 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,17 @@ db.Query("SELECT * FROM foobar WHERE id=?", 1, sql.Named("X-Trino-User", string(

The position of the X-Trino-User NamedArg is irrelevant and does not affect the query in any way.

#### Custom request headers

It is also possible to pass custom headers to the Trino request. You must pass a [NamedArg](https://godoc.org/database/sql#NamedArg) to the query parameters where the key begins with `X-Header-`.
The `X-Header-` prefix will be stripped from the header name.

Example (set `Authorization` header):

```go
db.Query("SELECT * FROM foobar WHERE id=?", 1, sql.Named("X-Header-Authorization", "Bearer xyz"))
```

### DSN (Data Source Name)

The Data Source Name is a URL with a mandatory username, and optional query string parameters that are supported by this driver, in the following format:
Expand Down
10 changes: 8 additions & 2 deletions trino/trino.go
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,8 @@ const (
trinoAddedPrepareHeader = trinoHeaderPrefix + `Added-Prepare`
trinoDeallocatedPrepareHeader = trinoHeaderPrefix + `Deallocated-Prepare`

passthroughHeaderPrefix = `X-Header-`

KerberosEnabledConfig = "KerberosEnabled"
kerberosKeytabPathConfig = "KerberosKeytabPath"
kerberosPrincipalConfig = "KerberosPrincipal"
Expand Down Expand Up @@ -789,14 +791,18 @@ func (st *driverStmt) exec(ctx context.Context, args []driver.NamedValue) (*stmt
return nil, err
}

if strings.HasPrefix(arg.Name, trinoHeaderPrefix) {
isPassthroughHeader := strings.HasPrefix(arg.Name, passthroughHeaderPrefix)
if isPassthroughHeader || strings.HasPrefix(arg.Name, trinoHeaderPrefix) {
headerName := arg.Name
headerValue := arg.Value.(string)

if arg.Name == trinoUserHeader {
st.user = headerValue
} else if isPassthroughHeader {
headerName = strings.TrimPrefix(headerName, passthroughHeaderPrefix)
}

hs.Add(arg.Name, headerValue)
hs.Add(headerName, headerValue)
} else {
if hs.Get(preparedStatementHeader) == "" {
for _, v := range st.conn.httpHeaders.Values(preparedStatementHeader) {
Expand Down
25 changes: 25 additions & 0 deletions trino/trino_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -269,6 +269,31 @@ func TestAuthFailure(t *testing.T) {
assert.NoError(t, db.Close())
}

func TestPassthroughHeader(t *testing.T) {
headerValue := "secret"
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
authHeader := r.Header.Get("Authorization")
assert.Equal(t, headerValue, authHeader)
json.NewEncoder(w).Encode(&stmtResponse{
Error: stmtError{
ErrorName: "TEST",
},
})
}))

t.Cleanup(ts.Close)

db, err := sql.Open("trino", ts.URL)
require.NoError(t, err)

t.Cleanup(func() {
assert.NoError(t, db.Close())
})

_, err = db.Query("SELECT 1", sql.Named("X-Header-Authorization", headerValue))
assert.IsTypef(t, new(ErrQueryFailed), err, "unexpected error: %w", err)
}

func TestQueryForUsername(t *testing.T) {
if testing.Short() {
t.Skip("Skipping test in short mode.")
Expand Down

0 comments on commit 8fdd3e9

Please sign in to comment.