Skip to content

Commit

Permalink
Revert "[Monorepo] Enable proxy-authorization in admin client (#4189)"
Browse files Browse the repository at this point in the history
Signed-off-by: Andrew Dye <[email protected]>
  • Loading branch information
andrewwdye committed Oct 12, 2023
1 parent da77476 commit 94daddb
Show file tree
Hide file tree
Showing 9 changed files with 24 additions and 241 deletions.
85 changes: 8 additions & 77 deletions flyteidl/clients/go/admin/auth_interceptor.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@ package admin

import (
"context"
"errors"
"fmt"
"net/http"

Expand All @@ -17,12 +16,10 @@ import (
"google.golang.org/grpc"
)

const ProxyAuthorizationHeader = "proxy-authorization"

// MaterializeCredentials will attempt to build a TokenSource given the anonymously available information exposed by the server.
// Once established, it'll invoke PerRPCCredentialsFuture.Store() on perRPCCredentials to populate it with the appropriate values.
func MaterializeCredentials(ctx context.Context, cfg *Config, tokenCache cache.TokenCache, perRPCCredentials *PerRPCCredentialsFuture, proxyCredentialsFuture *PerRPCCredentialsFuture) error {
authMetadataClient, err := InitializeAuthMetadataClient(ctx, cfg, proxyCredentialsFuture)
func MaterializeCredentials(ctx context.Context, cfg *Config, tokenCache cache.TokenCache, perRPCCredentials *PerRPCCredentialsFuture) error {
authMetadataClient, err := InitializeAuthMetadataClient(ctx, cfg)
if err != nil {
return fmt.Errorf("failed to initialized Auth Metadata Client. Error: %w", err)
}
Expand Down Expand Up @@ -51,70 +48,19 @@ func MaterializeCredentials(ctx context.Context, cfg *Config, tokenCache cache.T
return nil
}

func GetProxyTokenSource(ctx context.Context, cfg *Config) (oauth2.TokenSource, error) {
tokenSourceProvider, err := NewExternalTokenSourceProvider(cfg.ProxyCommand)
if err != nil {
return nil, fmt.Errorf("failed to initialized proxy authorization token source provider. Err: %w", err)
}
proxyTokenSource, err := tokenSourceProvider.GetTokenSource(ctx)
if err != nil {
return nil, err
}
return proxyTokenSource, nil
}

func MaterializeProxyAuthCredentials(ctx context.Context, cfg *Config, proxyCredentialsFuture *PerRPCCredentialsFuture) error {
proxyTokenSource, err := GetProxyTokenSource(ctx, cfg)
if err != nil {
return err
}

wrappedTokenSource := NewCustomHeaderTokenSource(proxyTokenSource, cfg.UseInsecureConnection, ProxyAuthorizationHeader)
proxyCredentialsFuture.Store(wrappedTokenSource)

return nil
}

func shouldAttemptToAuthenticate(errorCode codes.Code) bool {
return errorCode == codes.Unauthenticated
}

type proxyAuthTransport struct {
transport http.RoundTripper
proxyCredentialsFuture *PerRPCCredentialsFuture
}

func (c *proxyAuthTransport) RoundTrip(req *http.Request) (*http.Response, error) {
// check if the proxy credentials future is initialized
if !c.proxyCredentialsFuture.IsInitialized() {
return nil, errors.New("proxy credentials future is not initialized")
}

metadata, err := c.proxyCredentialsFuture.GetRequestMetadata(context.Background(), "")
if err != nil {
return nil, err
}
token := metadata[ProxyAuthorizationHeader]
req.Header.Add(ProxyAuthorizationHeader, token)
return c.transport.RoundTrip(req)
}

// Set up http client used in oauth2
func setHTTPClientContext(ctx context.Context, cfg *Config, proxyCredentialsFuture *PerRPCCredentialsFuture) context.Context {
func setHTTPClientContext(ctx context.Context, cfg *Config) context.Context {
httpClient := &http.Client{}
transport := &http.Transport{}

if len(cfg.HTTPProxyURL.String()) > 0 {
// create a transport that uses the proxy
transport.Proxy = http.ProxyURL(&cfg.HTTPProxyURL.URL)
}

if cfg.ProxyCommand != nil {
httpClient.Transport = &proxyAuthTransport{
transport: transport,
proxyCredentialsFuture: proxyCredentialsFuture,
transport := &http.Transport{
Proxy: http.ProxyURL(&cfg.HTTPProxyURL.URL),

Check warning on line 62 in flyteidl/clients/go/admin/auth_interceptor.go

View check run for this annotation

Codecov / codecov/patch

flyteidl/clients/go/admin/auth_interceptor.go#L61-L62

Added lines #L61 - L62 were not covered by tests
}
} else {
httpClient.Transport = transport
}

Expand All @@ -131,9 +77,9 @@ func setHTTPClientContext(ctx context.Context, cfg *Config, proxyCredentialsFutu
// more. It'll fail hard if it couldn't do so (i.e. it will no longer attempt to send an unauthenticated request). Once
// a token source has been created, it'll invoke the grpc pipeline again, this time the grpc.PerRPCCredentials should
// be able to find and acquire a valid AccessToken to annotate the request with.
func NewAuthInterceptor(cfg *Config, tokenCache cache.TokenCache, credentialsFuture *PerRPCCredentialsFuture, proxyCredentialsFuture *PerRPCCredentialsFuture) grpc.UnaryClientInterceptor {
func NewAuthInterceptor(cfg *Config, tokenCache cache.TokenCache, credentialsFuture *PerRPCCredentialsFuture) grpc.UnaryClientInterceptor {
return func(ctx context.Context, method string, req, reply interface{}, cc *grpc.ClientConn, invoker grpc.UnaryInvoker, opts ...grpc.CallOption) error {
ctx = setHTTPClientContext(ctx, cfg, proxyCredentialsFuture)
ctx = setHTTPClientContext(ctx, cfg)

err := invoker(ctx, method, req, reply, cc, opts...)
if err != nil {
Expand All @@ -143,7 +89,7 @@ func NewAuthInterceptor(cfg *Config, tokenCache cache.TokenCache, credentialsFut
// If the error we receive from executing the request expects
if shouldAttemptToAuthenticate(st.Code()) {
logger.Debugf(ctx, "Request failed due to [%v]. Attempting to establish an authenticated connection and trying again.", st.Code())
newErr := MaterializeCredentials(ctx, cfg, tokenCache, credentialsFuture, proxyCredentialsFuture)
newErr := MaterializeCredentials(ctx, cfg, tokenCache, credentialsFuture)
if newErr != nil {
return fmt.Errorf("authentication error! Original Error: %v, Auth Error: %w", err, newErr)
}
Expand All @@ -156,18 +102,3 @@ func NewAuthInterceptor(cfg *Config, tokenCache cache.TokenCache, credentialsFut
return err
}
}

func NewProxyAuthInterceptor(cfg *Config, proxyCredentialsFuture *PerRPCCredentialsFuture) grpc.UnaryClientInterceptor {
return func(ctx context.Context, method string, req, reply interface{}, cc *grpc.ClientConn, invoker grpc.UnaryInvoker, opts ...grpc.CallOption) error {

err := invoker(ctx, method, req, reply, cc, opts...)
if err != nil {
newErr := MaterializeProxyAuthCredentials(ctx, cfg, proxyCredentialsFuture)
if newErr != nil {
return fmt.Errorf("proxy authorization error! Original Error: %v, Proxy Auth Error: %w", err, newErr)
}
return invoker(ctx, method, req, reply, cc, opts...)
}
return err
}
}
127 changes: 6 additions & 121 deletions flyteidl/clients/go/admin/auth_interceptor_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@ import (

"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/mock"
"golang.org/x/oauth2"
"google.golang.org/grpc"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
Expand Down Expand Up @@ -115,8 +114,7 @@ func newAuthMetadataServer(t testing.TB, port int, impl service.AuthMetadataServ
func Test_newAuthInterceptor(t *testing.T) {
t.Run("Other Error", func(t *testing.T) {
f := NewPerRPCCredentialsFuture()
p := NewPerRPCCredentialsFuture()
interceptor := NewAuthInterceptor(&Config{}, &mocks.TokenCache{}, f, p)
interceptor := NewAuthInterceptor(&Config{}, &mocks.TokenCache{}, f)
otherError := func(ctx context.Context, method string, req, reply interface{}, cc *grpc.ClientConn, opts ...grpc.CallOption) error {
return status.New(codes.Canceled, "").Err()
}
Expand Down Expand Up @@ -148,12 +146,11 @@ func Test_newAuthInterceptor(t *testing.T) {
assert.NoError(t, err)

f := NewPerRPCCredentialsFuture()
p := NewPerRPCCredentialsFuture()
interceptor := NewAuthInterceptor(&Config{
Endpoint: config.URL{URL: *u},
UseInsecureConnection: true,
AuthType: AuthTypeClientSecret,
}, &mocks.TokenCache{}, f, p)
}, &mocks.TokenCache{}, f)
unauthenticated := func(ctx context.Context, method string, req, reply interface{}, cc *grpc.ClientConn, opts ...grpc.CallOption) error {
return status.New(codes.Unauthenticated, "").Err()
}
Expand All @@ -180,13 +177,11 @@ func Test_newAuthInterceptor(t *testing.T) {
assert.NoError(t, err)

f := NewPerRPCCredentialsFuture()
p := NewPerRPCCredentialsFuture()

interceptor := NewAuthInterceptor(&Config{
Endpoint: config.URL{URL: *u},
UseInsecureConnection: true,
AuthType: AuthTypeClientSecret,
}, &mocks.TokenCache{}, f, p)
}, &mocks.TokenCache{}, f)
authenticated := func(ctx context.Context, method string, req, reply interface{}, cc *grpc.ClientConn, opts ...grpc.CallOption) error {
return nil
}
Expand Down Expand Up @@ -221,13 +216,11 @@ func Test_newAuthInterceptor(t *testing.T) {
assert.NoError(t, err)

f := NewPerRPCCredentialsFuture()
p := NewPerRPCCredentialsFuture()

interceptor := NewAuthInterceptor(&Config{
Endpoint: config.URL{URL: *u},
UseInsecureConnection: true,
AuthType: AuthTypeClientSecret,
}, &mocks.TokenCache{}, f, p)
}, &mocks.TokenCache{}, f)
unauthenticated := func(ctx context.Context, method string, req, reply interface{}, cc *grpc.ClientConn, opts ...grpc.CallOption) error {
return status.New(codes.Aborted, "").Err()
}
Expand All @@ -253,8 +246,6 @@ func TestMaterializeCredentials(t *testing.T) {
assert.NoError(t, err)

f := NewPerRPCCredentialsFuture()
p := NewPerRPCCredentialsFuture()

err = MaterializeCredentials(ctx, &Config{
Endpoint: config.URL{URL: *u},
UseInsecureConnection: true,
Expand All @@ -263,7 +254,7 @@ func TestMaterializeCredentials(t *testing.T) {
Scopes: []string{"all"},
Audience: "http://localhost:30081",
AuthorizationHeader: "authorization",
}, &mocks.TokenCache{}, f, p)
}, &mocks.TokenCache{}, f)
assert.NoError(t, err)
})
t.Run("Failed to fetch client metadata", func(t *testing.T) {
Expand All @@ -280,119 +271,13 @@ func TestMaterializeCredentials(t *testing.T) {
assert.NoError(t, err)

f := NewPerRPCCredentialsFuture()
p := NewPerRPCCredentialsFuture()

err = MaterializeCredentials(ctx, &Config{
Endpoint: config.URL{URL: *u},
UseInsecureConnection: true,
AuthType: AuthTypeClientSecret,
TokenURL: fmt.Sprintf("http://localhost:%d/api/v1/token", port),
Scopes: []string{"all"},
}, &mocks.TokenCache{}, f, p)
}, &mocks.TokenCache{}, f)
assert.EqualError(t, err, "failed to fetch client metadata. Error: rpc error: code = Unknown desc = expected err")
})
}

func TestNewProxyAuthInterceptor(t *testing.T) {
cfg := &Config{
ProxyCommand: []string{"echo", "test-token"},
}

p := NewPerRPCCredentialsFuture()

interceptor := NewProxyAuthInterceptor(cfg, p)

ctx := context.Background()
method := "/test.method"
req := "request"
reply := "reply"
cc := new(grpc.ClientConn)

errorInvoker := func(ctx context.Context, method string, req, reply interface{}, cc *grpc.ClientConn, opts ...grpc.CallOption) error {
return errors.New("test error")
}

// Call should return an error and trigger the interceptor to materialize proxy auth credentials
err := interceptor(ctx, method, req, reply, cc, errorInvoker)
assert.Error(t, err)

// Check if proxyCredentialsFuture contains a proxy auth header token
creds, err := p.Get().GetRequestMetadata(ctx, "")
assert.True(t, p.IsInitialized())
assert.NoError(t, err)
assert.Equal(t, "Bearer test-token", creds[ProxyAuthorizationHeader])
}

type testRoundTripper struct {
RoundTripFunc func(req *http.Request) (*http.Response, error)
}

func (t *testRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
return t.RoundTripFunc(req)
}

func TestSetHTTPClientContext(t *testing.T) {
ctx := context.Background()

t.Run("no proxy command and no proxy url", func(t *testing.T) {
cfg := &Config{}

newCtx := setHTTPClientContext(ctx, cfg, nil)

httpClient, ok := newCtx.Value(oauth2.HTTPClient).(*http.Client)
assert.True(t, ok)

transport, ok := httpClient.Transport.(*http.Transport)
assert.True(t, ok)
assert.Nil(t, transport.Proxy)
})

t.Run("proxy url", func(t *testing.T) {
cfg := &Config{
HTTPProxyURL: config.
URL{URL: url.URL{
Scheme: "http",
Host: "localhost:8080",
}},
}
newCtx := setHTTPClientContext(ctx, cfg, nil)

httpClient, ok := newCtx.Value(oauth2.HTTPClient).(*http.Client)
assert.True(t, ok)

transport, ok := httpClient.Transport.(*http.Transport)
assert.True(t, ok)
assert.NotNil(t, transport.Proxy)
})

t.Run("proxy command adds proxy-authorization header", func(t *testing.T) {
cfg := &Config{
ProxyCommand: []string{"echo", "test-token-http-client"},
}

p := NewPerRPCCredentialsFuture()
err := MaterializeProxyAuthCredentials(ctx, cfg, p)
assert.NoError(t, err)

newCtx := setHTTPClientContext(ctx, cfg, p)

httpClient, ok := newCtx.Value(oauth2.HTTPClient).(*http.Client)
assert.True(t, ok)

pat, ok := httpClient.Transport.(*proxyAuthTransport)
assert.True(t, ok)

testRoundTripper := &testRoundTripper{
RoundTripFunc: func(req *http.Request) (*http.Response, error) {
// Check if the ProxyAuthorizationHeader is correctly set
assert.Equal(t, "Bearer test-token-http-client", req.Header.Get(ProxyAuthorizationHeader))
return &http.Response{StatusCode: http.StatusOK}, nil
},
}
pat.transport = testRoundTripper

req, _ := http.NewRequest("GET", "http://example.com", nil)
_, err = httpClient.Do(req)
assert.NoError(t, err)
})
}
19 changes: 6 additions & 13 deletions flyteidl/clients/go/admin/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -110,21 +110,21 @@ func getAuthenticationDialOption(ctx context.Context, cfg *Config, tokenSourcePr
}

// InitializeAuthMetadataClient creates a new anonymously Auth Metadata Service client.
func InitializeAuthMetadataClient(ctx context.Context, cfg *Config, proxyCredentialsFuture *PerRPCCredentialsFuture) (client service.AuthMetadataServiceClient, err error) {
func InitializeAuthMetadataClient(ctx context.Context, cfg *Config) (client service.AuthMetadataServiceClient, err error) {
// Create an unauthenticated connection to fetch AuthMetadata
authMetadataConnection, err := NewAdminConnection(ctx, cfg, proxyCredentialsFuture)
authMetadataConnection, err := NewAdminConnection(ctx, cfg)
if err != nil {
return nil, fmt.Errorf("failed to initialized admin connection. Error: %w", err)
}

return service.NewAuthMetadataServiceClient(authMetadataConnection), nil
}

func NewAdminConnection(ctx context.Context, cfg *Config, proxyCredentialsFuture *PerRPCCredentialsFuture, opts ...grpc.DialOption) (*grpc.ClientConn, error) {
func NewAdminConnection(ctx context.Context, cfg *Config, opts ...grpc.DialOption) (*grpc.ClientConn, error) {
if opts == nil {
// Initialize opts list to the potential number of options we will add. Initialization optimizes memory
// allocation.
opts = make([]grpc.DialOption, 0, 7)
opts = make([]grpc.DialOption, 0, 5)
}

if cfg.UseInsecureConnection {
Expand Down Expand Up @@ -153,11 +153,6 @@ func NewAdminConnection(ctx context.Context, cfg *Config, proxyCredentialsFuture

opts = append(opts, GetAdditionalAdminClientConfigOptions(cfg)...)

if cfg.ProxyCommand != nil {
opts = append(opts, grpc.WithChainUnaryInterceptor(NewProxyAuthInterceptor(cfg, proxyCredentialsFuture)))
opts = append(opts, grpc.WithPerRPCCredentials(proxyCredentialsFuture))
}

return grpc.Dial(cfg.Endpoint.String(), opts...)
}

Expand All @@ -177,17 +172,15 @@ func InitializeAdminClient(ctx context.Context, cfg *Config, opts ...grpc.DialOp
// for the process. Note that if called with different cfg/dialoptions, it will not refresh the connection.
func initializeClients(ctx context.Context, cfg *Config, tokenCache cache.TokenCache, opts ...grpc.DialOption) (*Clientset, error) {
credentialsFuture := NewPerRPCCredentialsFuture()
proxyCredentialsFuture := NewPerRPCCredentialsFuture()

opts = append(opts,
grpc.WithChainUnaryInterceptor(NewAuthInterceptor(cfg, tokenCache, credentialsFuture, proxyCredentialsFuture)),
grpc.WithChainUnaryInterceptor(NewAuthInterceptor(cfg, tokenCache, credentialsFuture)),
grpc.WithPerRPCCredentials(credentialsFuture))

if cfg.DefaultServiceConfig != "" {
opts = append(opts, grpc.WithDefaultServiceConfig(cfg.DefaultServiceConfig))
}

adminConnection, err := NewAdminConnection(ctx, cfg, proxyCredentialsFuture, opts...)
adminConnection, err := NewAdminConnection(ctx, cfg, opts...)
if err != nil {
logger.Panicf(ctx, "failed to initialized Admin connection. Err: %s", err.Error())
}
Expand Down
Loading

0 comments on commit 94daddb

Please sign in to comment.