Skip to content

Commit

Permalink
[Monorepo] Enable proxy-authorization in admin client (#4189)
Browse files Browse the repository at this point in the history
* Add proxyCommand to client config

Signed-off-by: Fabio Grätz <[email protected]>

* Add proxy auth unary interceptor

Signed-off-by: Fabio Grätz <[email protected]>

* Use proxy auth in http client for oauth

Signed-off-by: Fabio Grätz <[email protected]>

* Cache tokens obtained from external commands

Signed-off-by: Fabio Grätz <[email protected]>

* Make tests pass

Signed-off-by: Fabio Grätz <[email protected]>

* Add tests for proxy auth interceptor

Signed-off-by: Fabio Grätz <[email protected]>

* Make work without 2nd token cache but instead with 2nd credentials future

Signed-off-by: Fabio Grätz <[email protected]>

* Adapt existing tests to not using a 2nd token cache but a 2nd credentials future

Signed-off-by: Fabio Grätz <[email protected]>

* Adapt new tests to not using a 2nd token cache but a 2nd credentials future

Signed-off-by: Fabio Grätz <[email protected]>

* Fix number of opts in NewAdminConnection

Signed-off-by: Fabio Grätz <[email protected]>

* Don't overwrite original error in NewProxyAuthInterceptor

Signed-off-by: Fabio Grätz <[email protected]>

* Improve error message when failing to create http client for oauth

Signed-off-by: Fabio Grätz <[email protected]>

* Actually don't return any error from setHTTPClientContext at all as before

Signed-off-by: Fabio Grätz <[email protected]>

* Don't require github.com/golang-jwt/jwt anymore

Signed-off-by: Fabio Grätz <[email protected]>

* Make tests pass again

Signed-off-by: Fabio Grätz <[email protected]>

* Lint

Signed-off-by: Fabio Grätz <[email protected]>

* make -C flyteidl generate

Signed-off-by: Eduardo Apolinario <[email protected]>

* Fix flytepropeller's getAdminClient

Signed-off-by: Fabio Grätz <[email protected]>

---------

Signed-off-by: Fabio Grätz <[email protected]>
Signed-off-by: Eduardo Apolinario <[email protected]>
Co-authored-by: Fabio Grätz <[email protected]>
Co-authored-by: Eduardo Apolinario <[email protected]>
  • Loading branch information
3 people authored Oct 11, 2023
1 parent 26228bd commit b9f6e8c
Show file tree
Hide file tree
Showing 10 changed files with 242 additions and 25 deletions.
85 changes: 77 additions & 8 deletions flyteidl/clients/go/admin/auth_interceptor.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package admin

import (
"context"
"errors"
"fmt"
"net/http"

Expand All @@ -16,10 +17,12 @@ import (
"google.golang.org/grpc"
)

const ProxyAuthorizationHeader = "proxy-authorization"

// MaterializeCredentials will attempt to build a TokenSource given the anonymously available information exposed by the server.
// Once established, it'll invoke PerRPCCredentialsFuture.Store() on perRPCCredentials to populate it with the appropriate values.
func MaterializeCredentials(ctx context.Context, cfg *Config, tokenCache cache.TokenCache, perRPCCredentials *PerRPCCredentialsFuture) error {
authMetadataClient, err := InitializeAuthMetadataClient(ctx, cfg)
func MaterializeCredentials(ctx context.Context, cfg *Config, tokenCache cache.TokenCache, perRPCCredentials *PerRPCCredentialsFuture, proxyCredentialsFuture *PerRPCCredentialsFuture) error {
authMetadataClient, err := InitializeAuthMetadataClient(ctx, cfg, proxyCredentialsFuture)
if err != nil {
return fmt.Errorf("failed to initialized Auth Metadata Client. Error: %w", err)
}
Expand Down Expand Up @@ -48,19 +51,70 @@ func MaterializeCredentials(ctx context.Context, cfg *Config, tokenCache cache.T
return nil
}

func GetProxyTokenSource(ctx context.Context, cfg *Config) (oauth2.TokenSource, error) {
tokenSourceProvider, err := NewExternalTokenSourceProvider(cfg.ProxyCommand)
if err != nil {
return nil, fmt.Errorf("failed to initialized proxy authorization token source provider. Err: %w", err)
}
proxyTokenSource, err := tokenSourceProvider.GetTokenSource(ctx)
if err != nil {
return nil, err
}
return proxyTokenSource, nil
}

func MaterializeProxyAuthCredentials(ctx context.Context, cfg *Config, proxyCredentialsFuture *PerRPCCredentialsFuture) error {
proxyTokenSource, err := GetProxyTokenSource(ctx, cfg)
if err != nil {
return err
}

wrappedTokenSource := NewCustomHeaderTokenSource(proxyTokenSource, cfg.UseInsecureConnection, ProxyAuthorizationHeader)
proxyCredentialsFuture.Store(wrappedTokenSource)

return nil
}

func shouldAttemptToAuthenticate(errorCode codes.Code) bool {
return errorCode == codes.Unauthenticated
}

type proxyAuthTransport struct {
transport http.RoundTripper
proxyCredentialsFuture *PerRPCCredentialsFuture
}

func (c *proxyAuthTransport) RoundTrip(req *http.Request) (*http.Response, error) {
// check if the proxy credentials future is initialized
if !c.proxyCredentialsFuture.IsInitialized() {
return nil, errors.New("proxy credentials future is not initialized")
}

metadata, err := c.proxyCredentialsFuture.GetRequestMetadata(context.Background(), "")
if err != nil {
return nil, err
}
token := metadata[ProxyAuthorizationHeader]
req.Header.Add(ProxyAuthorizationHeader, token)
return c.transport.RoundTrip(req)
}

// Set up http client used in oauth2
func setHTTPClientContext(ctx context.Context, cfg *Config) context.Context {
func setHTTPClientContext(ctx context.Context, cfg *Config, proxyCredentialsFuture *PerRPCCredentialsFuture) context.Context {
httpClient := &http.Client{}
transport := &http.Transport{}

if len(cfg.HTTPProxyURL.String()) > 0 {
// create a transport that uses the proxy
transport := &http.Transport{
Proxy: http.ProxyURL(&cfg.HTTPProxyURL.URL),
transport.Proxy = http.ProxyURL(&cfg.HTTPProxyURL.URL)
}

if cfg.ProxyCommand != nil {
httpClient.Transport = &proxyAuthTransport{
transport: transport,
proxyCredentialsFuture: proxyCredentialsFuture,
}
} else {
httpClient.Transport = transport
}

Expand All @@ -77,9 +131,9 @@ func setHTTPClientContext(ctx context.Context, cfg *Config) context.Context {
// more. It'll fail hard if it couldn't do so (i.e. it will no longer attempt to send an unauthenticated request). Once
// a token source has been created, it'll invoke the grpc pipeline again, this time the grpc.PerRPCCredentials should
// be able to find and acquire a valid AccessToken to annotate the request with.
func NewAuthInterceptor(cfg *Config, tokenCache cache.TokenCache, credentialsFuture *PerRPCCredentialsFuture) grpc.UnaryClientInterceptor {
func NewAuthInterceptor(cfg *Config, tokenCache cache.TokenCache, credentialsFuture *PerRPCCredentialsFuture, proxyCredentialsFuture *PerRPCCredentialsFuture) grpc.UnaryClientInterceptor {
return func(ctx context.Context, method string, req, reply interface{}, cc *grpc.ClientConn, invoker grpc.UnaryInvoker, opts ...grpc.CallOption) error {
ctx = setHTTPClientContext(ctx, cfg)
ctx = setHTTPClientContext(ctx, cfg, proxyCredentialsFuture)

err := invoker(ctx, method, req, reply, cc, opts...)
if err != nil {
Expand All @@ -89,7 +143,7 @@ func NewAuthInterceptor(cfg *Config, tokenCache cache.TokenCache, credentialsFut
// If the error we receive from executing the request expects
if shouldAttemptToAuthenticate(st.Code()) {
logger.Debugf(ctx, "Request failed due to [%v]. Attempting to establish an authenticated connection and trying again.", st.Code())
newErr := MaterializeCredentials(ctx, cfg, tokenCache, credentialsFuture)
newErr := MaterializeCredentials(ctx, cfg, tokenCache, credentialsFuture, proxyCredentialsFuture)
if newErr != nil {
return fmt.Errorf("authentication error! Original Error: %v, Auth Error: %w", err, newErr)
}
Expand All @@ -102,3 +156,18 @@ func NewAuthInterceptor(cfg *Config, tokenCache cache.TokenCache, credentialsFut
return err
}
}

func NewProxyAuthInterceptor(cfg *Config, proxyCredentialsFuture *PerRPCCredentialsFuture) grpc.UnaryClientInterceptor {
return func(ctx context.Context, method string, req, reply interface{}, cc *grpc.ClientConn, invoker grpc.UnaryInvoker, opts ...grpc.CallOption) error {

err := invoker(ctx, method, req, reply, cc, opts...)
if err != nil {
newErr := MaterializeProxyAuthCredentials(ctx, cfg, proxyCredentialsFuture)
if newErr != nil {
return fmt.Errorf("proxy authorization error! Original Error: %v, Proxy Auth Error: %w", err, newErr)
}
return invoker(ctx, method, req, reply, cc, opts...)
}
return err
}
}
127 changes: 121 additions & 6 deletions flyteidl/clients/go/admin/auth_interceptor_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ import (

"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/mock"
"golang.org/x/oauth2"
"google.golang.org/grpc"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
Expand Down Expand Up @@ -114,7 +115,8 @@ func newAuthMetadataServer(t testing.TB, port int, impl service.AuthMetadataServ
func Test_newAuthInterceptor(t *testing.T) {
t.Run("Other Error", func(t *testing.T) {
f := NewPerRPCCredentialsFuture()
interceptor := NewAuthInterceptor(&Config{}, &mocks.TokenCache{}, f)
p := NewPerRPCCredentialsFuture()
interceptor := NewAuthInterceptor(&Config{}, &mocks.TokenCache{}, f, p)
otherError := func(ctx context.Context, method string, req, reply interface{}, cc *grpc.ClientConn, opts ...grpc.CallOption) error {
return status.New(codes.Canceled, "").Err()
}
Expand Down Expand Up @@ -146,11 +148,12 @@ func Test_newAuthInterceptor(t *testing.T) {
assert.NoError(t, err)

f := NewPerRPCCredentialsFuture()
p := NewPerRPCCredentialsFuture()
interceptor := NewAuthInterceptor(&Config{
Endpoint: config.URL{URL: *u},
UseInsecureConnection: true,
AuthType: AuthTypeClientSecret,
}, &mocks.TokenCache{}, f)
}, &mocks.TokenCache{}, f, p)
unauthenticated := func(ctx context.Context, method string, req, reply interface{}, cc *grpc.ClientConn, opts ...grpc.CallOption) error {
return status.New(codes.Unauthenticated, "").Err()
}
Expand All @@ -177,11 +180,13 @@ func Test_newAuthInterceptor(t *testing.T) {
assert.NoError(t, err)

f := NewPerRPCCredentialsFuture()
p := NewPerRPCCredentialsFuture()

interceptor := NewAuthInterceptor(&Config{
Endpoint: config.URL{URL: *u},
UseInsecureConnection: true,
AuthType: AuthTypeClientSecret,
}, &mocks.TokenCache{}, f)
}, &mocks.TokenCache{}, f, p)
authenticated := func(ctx context.Context, method string, req, reply interface{}, cc *grpc.ClientConn, opts ...grpc.CallOption) error {
return nil
}
Expand Down Expand Up @@ -216,11 +221,13 @@ func Test_newAuthInterceptor(t *testing.T) {
assert.NoError(t, err)

f := NewPerRPCCredentialsFuture()
p := NewPerRPCCredentialsFuture()

interceptor := NewAuthInterceptor(&Config{
Endpoint: config.URL{URL: *u},
UseInsecureConnection: true,
AuthType: AuthTypeClientSecret,
}, &mocks.TokenCache{}, f)
}, &mocks.TokenCache{}, f, p)
unauthenticated := func(ctx context.Context, method string, req, reply interface{}, cc *grpc.ClientConn, opts ...grpc.CallOption) error {
return status.New(codes.Aborted, "").Err()
}
Expand All @@ -246,6 +253,8 @@ func TestMaterializeCredentials(t *testing.T) {
assert.NoError(t, err)

f := NewPerRPCCredentialsFuture()
p := NewPerRPCCredentialsFuture()

err = MaterializeCredentials(ctx, &Config{
Endpoint: config.URL{URL: *u},
UseInsecureConnection: true,
Expand All @@ -254,7 +263,7 @@ func TestMaterializeCredentials(t *testing.T) {
Scopes: []string{"all"},
Audience: "http://localhost:30081",
AuthorizationHeader: "authorization",
}, &mocks.TokenCache{}, f)
}, &mocks.TokenCache{}, f, p)
assert.NoError(t, err)
})
t.Run("Failed to fetch client metadata", func(t *testing.T) {
Expand All @@ -271,13 +280,119 @@ func TestMaterializeCredentials(t *testing.T) {
assert.NoError(t, err)

f := NewPerRPCCredentialsFuture()
p := NewPerRPCCredentialsFuture()

err = MaterializeCredentials(ctx, &Config{
Endpoint: config.URL{URL: *u},
UseInsecureConnection: true,
AuthType: AuthTypeClientSecret,
TokenURL: fmt.Sprintf("http://localhost:%d/api/v1/token", port),
Scopes: []string{"all"},
}, &mocks.TokenCache{}, f)
}, &mocks.TokenCache{}, f, p)
assert.EqualError(t, err, "failed to fetch client metadata. Error: rpc error: code = Unknown desc = expected err")
})
}

func TestNewProxyAuthInterceptor(t *testing.T) {
cfg := &Config{
ProxyCommand: []string{"echo", "test-token"},
}

p := NewPerRPCCredentialsFuture()

interceptor := NewProxyAuthInterceptor(cfg, p)

ctx := context.Background()
method := "/test.method"
req := "request"
reply := "reply"
cc := new(grpc.ClientConn)

errorInvoker := func(ctx context.Context, method string, req, reply interface{}, cc *grpc.ClientConn, opts ...grpc.CallOption) error {
return errors.New("test error")
}

// Call should return an error and trigger the interceptor to materialize proxy auth credentials
err := interceptor(ctx, method, req, reply, cc, errorInvoker)
assert.Error(t, err)

// Check if proxyCredentialsFuture contains a proxy auth header token
creds, err := p.Get().GetRequestMetadata(ctx, "")
assert.True(t, p.IsInitialized())
assert.NoError(t, err)
assert.Equal(t, "Bearer test-token", creds[ProxyAuthorizationHeader])
}

type testRoundTripper struct {
RoundTripFunc func(req *http.Request) (*http.Response, error)
}

func (t *testRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
return t.RoundTripFunc(req)
}

func TestSetHTTPClientContext(t *testing.T) {
ctx := context.Background()

t.Run("no proxy command and no proxy url", func(t *testing.T) {
cfg := &Config{}

newCtx := setHTTPClientContext(ctx, cfg, nil)

httpClient, ok := newCtx.Value(oauth2.HTTPClient).(*http.Client)
assert.True(t, ok)

transport, ok := httpClient.Transport.(*http.Transport)
assert.True(t, ok)
assert.Nil(t, transport.Proxy)
})

t.Run("proxy url", func(t *testing.T) {
cfg := &Config{
HTTPProxyURL: config.
URL{URL: url.URL{
Scheme: "http",
Host: "localhost:8080",
}},
}
newCtx := setHTTPClientContext(ctx, cfg, nil)

httpClient, ok := newCtx.Value(oauth2.HTTPClient).(*http.Client)
assert.True(t, ok)

transport, ok := httpClient.Transport.(*http.Transport)
assert.True(t, ok)
assert.NotNil(t, transport.Proxy)
})

t.Run("proxy command adds proxy-authorization header", func(t *testing.T) {
cfg := &Config{
ProxyCommand: []string{"echo", "test-token-http-client"},
}

p := NewPerRPCCredentialsFuture()
err := MaterializeProxyAuthCredentials(ctx, cfg, p)
assert.NoError(t, err)

newCtx := setHTTPClientContext(ctx, cfg, p)

httpClient, ok := newCtx.Value(oauth2.HTTPClient).(*http.Client)
assert.True(t, ok)

pat, ok := httpClient.Transport.(*proxyAuthTransport)
assert.True(t, ok)

testRoundTripper := &testRoundTripper{
RoundTripFunc: func(req *http.Request) (*http.Response, error) {
// Check if the ProxyAuthorizationHeader is correctly set
assert.Equal(t, "Bearer test-token-http-client", req.Header.Get(ProxyAuthorizationHeader))
return &http.Response{StatusCode: http.StatusOK}, nil
},
}
pat.transport = testRoundTripper

req, _ := http.NewRequest("GET", "http://example.com", nil)
_, err = httpClient.Do(req)
assert.NoError(t, err)
})
}
19 changes: 13 additions & 6 deletions flyteidl/clients/go/admin/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -110,21 +110,21 @@ func getAuthenticationDialOption(ctx context.Context, cfg *Config, tokenSourcePr
}

// InitializeAuthMetadataClient creates a new anonymously Auth Metadata Service client.
func InitializeAuthMetadataClient(ctx context.Context, cfg *Config) (client service.AuthMetadataServiceClient, err error) {
func InitializeAuthMetadataClient(ctx context.Context, cfg *Config, proxyCredentialsFuture *PerRPCCredentialsFuture) (client service.AuthMetadataServiceClient, err error) {
// Create an unauthenticated connection to fetch AuthMetadata
authMetadataConnection, err := NewAdminConnection(ctx, cfg)
authMetadataConnection, err := NewAdminConnection(ctx, cfg, proxyCredentialsFuture)
if err != nil {
return nil, fmt.Errorf("failed to initialized admin connection. Error: %w", err)
}

return service.NewAuthMetadataServiceClient(authMetadataConnection), nil
}

func NewAdminConnection(ctx context.Context, cfg *Config, opts ...grpc.DialOption) (*grpc.ClientConn, error) {
func NewAdminConnection(ctx context.Context, cfg *Config, proxyCredentialsFuture *PerRPCCredentialsFuture, opts ...grpc.DialOption) (*grpc.ClientConn, error) {
if opts == nil {
// Initialize opts list to the potential number of options we will add. Initialization optimizes memory
// allocation.
opts = make([]grpc.DialOption, 0, 5)
opts = make([]grpc.DialOption, 0, 7)
}

if cfg.UseInsecureConnection {
Expand Down Expand Up @@ -153,6 +153,11 @@ func NewAdminConnection(ctx context.Context, cfg *Config, opts ...grpc.DialOptio

opts = append(opts, GetAdditionalAdminClientConfigOptions(cfg)...)

if cfg.ProxyCommand != nil {
opts = append(opts, grpc.WithChainUnaryInterceptor(NewProxyAuthInterceptor(cfg, proxyCredentialsFuture)))
opts = append(opts, grpc.WithPerRPCCredentials(proxyCredentialsFuture))
}

return grpc.Dial(cfg.Endpoint.String(), opts...)
}

Expand All @@ -172,15 +177,17 @@ func InitializeAdminClient(ctx context.Context, cfg *Config, opts ...grpc.DialOp
// for the process. Note that if called with different cfg/dialoptions, it will not refresh the connection.
func initializeClients(ctx context.Context, cfg *Config, tokenCache cache.TokenCache, opts ...grpc.DialOption) (*Clientset, error) {
credentialsFuture := NewPerRPCCredentialsFuture()
proxyCredentialsFuture := NewPerRPCCredentialsFuture()

opts = append(opts,
grpc.WithChainUnaryInterceptor(NewAuthInterceptor(cfg, tokenCache, credentialsFuture)),
grpc.WithChainUnaryInterceptor(NewAuthInterceptor(cfg, tokenCache, credentialsFuture, proxyCredentialsFuture)),
grpc.WithPerRPCCredentials(credentialsFuture))

if cfg.DefaultServiceConfig != "" {
opts = append(opts, grpc.WithDefaultServiceConfig(cfg.DefaultServiceConfig))
}

adminConnection, err := NewAdminConnection(ctx, cfg, opts...)
adminConnection, err := NewAdminConnection(ctx, cfg, proxyCredentialsFuture, opts...)
if err != nil {
logger.Panicf(ctx, "failed to initialized Admin connection. Err: %s", err.Error())
}
Expand Down
Loading

0 comments on commit b9f6e8c

Please sign in to comment.