Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Upstream changes to fix token validity and utilizing inmemory creds source #6001

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 23 additions & 10 deletions flyteidl/clients/go/admin/auth_interceptor.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
"google.golang.org/grpc/status"

"github.com/flyteorg/flyte/flyteidl/clients/go/admin/cache"
"github.com/flyteorg/flyte/flyteidl/clients/go/admin/utils"
"github.com/flyteorg/flyte/flyteidl/gen/pb-go/flyteidl/service"
"github.com/flyteorg/flyte/flytestdlib/logger"
)
Expand All @@ -23,7 +24,6 @@
// Once established, it'll invoke PerRPCCredentialsFuture.Store() on perRPCCredentials to populate it with the appropriate values.
func MaterializeCredentials(tokenSource oauth2.TokenSource, cfg *Config, authorizationMetadataKey string,
perRPCCredentials *PerRPCCredentialsFuture) error {

_, err := tokenSource.Token()
if err != nil {
return fmt.Errorf("failed to issue token. Error: %w", err)
Expand All @@ -35,6 +35,19 @@
return nil
}

// MaterializeInMemoryCredentials initializes the perRPCCredentials with the token source containing in memory cached token.
// This path doesn't perform the token refresh and only build the cred source with cached token.
func MaterializeInMemoryCredentials(ctx context.Context, cfg *Config, tokenCache cache.TokenCache,
perRPCCredentials *PerRPCCredentialsFuture, authorizationMetadataKey string) error {
tokenSource, err := NewInMemoryTokenSourceProvider(tokenCache).GetTokenSource(ctx)
if err != nil {
return fmt.Errorf("failed to get token source. Error: %w", err)
}

Check warning on line 45 in flyteidl/clients/go/admin/auth_interceptor.go

View check run for this annotation

Codecov / codecov/patch

flyteidl/clients/go/admin/auth_interceptor.go#L44-L45

Added lines #L44 - L45 were not covered by tests
wrappedTokenSource := NewCustomHeaderTokenSource(tokenSource, cfg.UseInsecureConnection, authorizationMetadataKey)
perRPCCredentials.Store(wrappedTokenSource)
return nil
}

func GetProxyTokenSource(ctx context.Context, cfg *Config) (oauth2.TokenSource, error) {
tokenSourceProvider, err := NewExternalTokenSourceProvider(cfg.ProxyCommand)
if err != nil {
Expand Down Expand Up @@ -152,6 +165,7 @@
if err != nil {
logger.Errorf(context.Background(), "Failed to load token related config. Error: %v", err)
}
logger.Debugf(context.Background(), "Successfully loaded token related metadata")
})
if err != nil {
return err
Expand All @@ -176,22 +190,21 @@
}

return func(ctx context.Context, method string, req, reply interface{}, cc *grpc.ClientConn, invoker grpc.UnaryInvoker, opts ...grpc.CallOption) error {

ctx = setHTTPClientContext(ctx, cfg, proxyCredentialsFuture)

// If there is already a token in the cache (e.g. key-ring), we should use it immediately...
t, _ := tokenCache.GetToken()
if t != nil {

err := oauthMetadataProvider.GetOauthMetadata(cfg, tokenCache, proxyCredentialsFuture)
if err != nil {
return err
}
authorizationMetadataKey := oauthMetadataProvider.authorizationMetadataKey
tokenSource := oauthMetadataProvider.tokenSource

err = MaterializeCredentials(tokenSource, cfg, authorizationMetadataKey, credentialsFuture)
if err != nil {
return fmt.Errorf("failed to materialize credentials. Error: %v", err)
if isValid := utils.Valid(t); isValid {
err := MaterializeInMemoryCredentials(ctx, cfg, tokenCache, credentialsFuture, authorizationMetadataKey)
if err != nil {
return fmt.Errorf("failed to materialize credentials. Error: %v", err)
}

Check warning on line 207 in flyteidl/clients/go/admin/auth_interceptor.go

View check run for this annotation

Codecov / codecov/patch

flyteidl/clients/go/admin/auth_interceptor.go#L206-L207

Added lines #L206 - L207 were not covered by tests
}
}

Expand All @@ -208,13 +221,11 @@
}
authorizationMetadataKey := oauthMetadataProvider.authorizationMetadataKey
tokenSource := oauthMetadataProvider.tokenSource

err = func() error {
if !tokenCache.TryLock() {
tokenCache.CondWait()
return nil
}

defer tokenCache.Unlock()
_, err := tokenCache.PurgeIfEquals(t)
if err != nil && !errors.Is(err, cache.ErrNotFound) {
Expand All @@ -237,6 +248,7 @@
if err != nil {
return err
}

return invoker(ctx, method, req, reply, cc, opts...)
}
}
Expand All @@ -257,6 +269,7 @@
}
return invoker(ctx, method, req, reply, cc, opts...)
}

return err
}
}
13 changes: 5 additions & 8 deletions flyteidl/clients/go/admin/auth_interceptor_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,17 +2,16 @@ package admin

import (
"context"
"encoding/json"
"errors"
"fmt"
"io"
"net"
"net/http"
"net/url"
"os"
"strings"
"sync"
"testing"
"time"

"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/mock"
Expand All @@ -24,7 +23,7 @@ import (

"github.com/flyteorg/flyte/flyteidl/clients/go/admin/cache/mocks"
adminMocks "github.com/flyteorg/flyte/flyteidl/clients/go/admin/mocks"

"github.com/flyteorg/flyte/flyteidl/clients/go/admin/utils"
"github.com/flyteorg/flyte/flyteidl/gen/pb-go/flyteidl/service"
"github.com/flyteorg/flyte/flytestdlib/config"
"github.com/flyteorg/flyte/flytestdlib/logger"
Expand Down Expand Up @@ -137,10 +136,7 @@ func newAuthMetadataServer(t testing.TB, grpcPort int, httpPort int, impl servic
}

func Test_newAuthInterceptor(t *testing.T) {
plan, _ := os.ReadFile("tokenorchestrator/testdata/token.json")
var tokenData oauth2.Token
err := json.Unmarshal(plan, &tokenData)
assert.NoError(t, err)
tokenData := utils.GenTokenWithCustomExpiry(t, time.Now().Add(20*time.Minute))
t.Run("Other Error", func(t *testing.T) {
ctx := context.Background()
httpPort := rand.IntnRange(10000, 60000)
Expand All @@ -164,7 +160,8 @@ func Test_newAuthInterceptor(t *testing.T) {
f := NewPerRPCCredentialsFuture()
p := NewPerRPCCredentialsFuture()
mockTokenCache := &mocks.TokenCache{}
mockTokenCache.OnGetTokenMatch().Return(&tokenData, nil)

mockTokenCache.OnGetTokenMatch().Return(tokenData, nil)
mockTokenCache.OnSaveTokenMatch(mock.Anything).Return(nil)
interceptor := NewAuthInterceptor(&Config{
Endpoint: config.URL{URL: *u},
Expand Down
23 changes: 7 additions & 16 deletions flyteidl/clients/go/admin/client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,10 @@ package admin

import (
"context"
"encoding/json"
"errors"
"fmt"
"io/ioutil"
"net/http"
"net/url"
"os"
"testing"
"time"

Expand All @@ -24,6 +21,7 @@ import (
"github.com/flyteorg/flyte/flyteidl/clients/go/admin/oauth"
"github.com/flyteorg/flyte/flyteidl/clients/go/admin/pkce"
"github.com/flyteorg/flyte/flyteidl/clients/go/admin/tokenorchestrator"
"github.com/flyteorg/flyte/flyteidl/clients/go/admin/utils"
"github.com/flyteorg/flyte/flyteidl/gen/pb-go/flyteidl/service"
"github.com/flyteorg/flyte/flytestdlib/config"
"github.com/flyteorg/flyte/flytestdlib/logger"
Expand Down Expand Up @@ -231,15 +229,11 @@ func TestGetAuthenticationDialOptionPkce(t *testing.T) {
RedirectUri: "http://localhost:54545/callback",
}
http.DefaultServeMux = http.NewServeMux()
plan, _ := os.ReadFile("tokenorchestrator/testdata/token.json")
var tokenData oauth2.Token
err := json.Unmarshal(plan, &tokenData)
assert.NoError(t, err)
tokenData.Expiry = time.Now().Add(time.Minute)
tokenData := utils.GenTokenWithCustomExpiry(t, time.Now().Add(time.Minute))
t.Run("cache hit", func(t *testing.T) {
mockTokenCache := new(cachemocks.TokenCache)
mockAuthClient := new(mocks.AuthMetadataServiceClient)
mockTokenCache.OnGetTokenMatch().Return(&tokenData, nil)
mockTokenCache.OnGetTokenMatch().Return(tokenData, nil)
mockTokenCache.OnSaveTokenMatch(mock.Anything).Return(nil)
mockAuthClient.OnGetOAuth2MetadataMatch(mock.Anything, mock.Anything).Return(metadata, nil)
mockAuthClient.OnGetPublicClientConfigMatch(mock.Anything, mock.Anything).Return(clientMetatadata, nil)
Expand All @@ -249,11 +243,11 @@ func TestGetAuthenticationDialOptionPkce(t *testing.T) {
assert.NotNil(t, dialOption)
assert.Nil(t, err)
})
tokenData.Expiry = time.Now().Add(-time.Minute)
t.Run("cache miss auth failure", func(t *testing.T) {
tokenData = utils.GenTokenWithCustomExpiry(t, time.Now().Add(-time.Minute))
mockTokenCache := new(cachemocks.TokenCache)
mockAuthClient := new(mocks.AuthMetadataServiceClient)
mockTokenCache.OnGetTokenMatch().Return(&tokenData, nil)
mockTokenCache.OnGetTokenMatch().Return(tokenData, nil)
mockTokenCache.OnSaveTokenMatch(mock.Anything).Return(nil)
mockTokenCache.On("Lock").Return()
mockTokenCache.On("Unlock").Return()
Expand Down Expand Up @@ -284,14 +278,11 @@ func Test_getPkceAuthTokenSource(t *testing.T) {
mockAuthClient.OnGetPublicClientConfigMatch(mock.Anything, mock.Anything).Return(clientMetatadata, nil)

t.Run("cached token expired", func(t *testing.T) {
plan, _ := ioutil.ReadFile("tokenorchestrator/testdata/token.json")
var tokenData oauth2.Token
err := json.Unmarshal(plan, &tokenData)
assert.NoError(t, err)
tokenData := utils.GenTokenWithCustomExpiry(t, time.Now().Add(-time.Minute))

// populate the cache
tokenCache := cache.NewTokenCacheInMemoryProvider()
assert.NoError(t, tokenCache.SaveToken(&tokenData))
assert.NoError(t, tokenCache.SaveToken(tokenData))

baseOrchestrator := tokenorchestrator.BaseTokenOrchestrator{
ClientConfig: &oauth.Config{
Expand Down
46 changes: 39 additions & 7 deletions flyteidl/clients/go/admin/token_source_provider.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
"github.com/flyteorg/flyte/flyteidl/clients/go/admin/externalprocess"
"github.com/flyteorg/flyte/flyteidl/clients/go/admin/pkce"
"github.com/flyteorg/flyte/flyteidl/clients/go/admin/tokenorchestrator"
"github.com/flyteorg/flyte/flyteidl/clients/go/admin/utils"
"github.com/flyteorg/flyte/flyteidl/gen/pb-go/flyteidl/service"
"github.com/flyteorg/flyte/flytestdlib/logger"
)
Expand Down Expand Up @@ -229,28 +230,36 @@
s.mu.Lock()
defer s.mu.Unlock()

if token, err := s.tokenCache.GetToken(); err == nil && token.Valid() {
return token, nil
token, err := s.tokenCache.GetToken()
if err != nil {
logger.Warnf(s.ctx, "failed to get token from cache: %v", err)

Check warning on line 235 in flyteidl/clients/go/admin/token_source_provider.go

View check run for this annotation

Codecov / codecov/patch

flyteidl/clients/go/admin/token_source_provider.go#L235

Added line #L235 was not covered by tests
} else {
if isValid := utils.Valid(token); isValid {
logger.Infof(context.Background(), "retrieved token from cache with expiry %v", token.Expiry)
return token, nil
}
}

totalAttempts := s.cfg.MaxRetries + 1 // Add one for initial request attempt
backoff := wait.Backoff{
Duration: s.cfg.PerRetryTimeout.Duration,
Steps: totalAttempts,
}
var token *oauth2.Token
err := retry.OnError(backoff, func(err error) bool {

err = retry.OnError(backoff, func(err error) bool {
return err != nil
}, func() (err error) {
token, err = s.new.Token()
if err != nil {
logger.Infof(s.ctx, "failed to get token: %w", err)
return fmt.Errorf("failed to get token: %w", err)
logger.Infof(s.ctx, "failed to get new token: %w", err)
return fmt.Errorf("failed to get new token: %w", err)
}
logger.Infof(context.Background(), "Fetched new token with expiry %v", token.Expiry)
return nil
})
if err != nil {
return nil, err
logger.Warnf(s.ctx, "failed to get new token: %v", err)
return nil, fmt.Errorf("failed to get new token: %w", err)
}
logger.Infof(s.ctx, "retrieved token with expiry %v", token.Expiry)

Expand All @@ -262,6 +271,29 @@
return token, nil
}

type InMemoryTokenSourceProvider struct {
tokenCache cache.TokenCache
}

func NewInMemoryTokenSourceProvider(tokenCache cache.TokenCache) TokenSourceProvider {
return InMemoryTokenSourceProvider{tokenCache: tokenCache}
}

func (i InMemoryTokenSourceProvider) GetTokenSource(ctx context.Context) (oauth2.TokenSource, error) {
return GetInMemoryAuthTokenSource(ctx, i.tokenCache)
}

// GetInMemoryAuthTokenSource Returns the token source with cached token
func GetInMemoryAuthTokenSource(ctx context.Context, tokenCache cache.TokenCache) (oauth2.TokenSource, error) {
authToken, err := tokenCache.GetToken()
if err != nil {
return nil, err
}

Check warning on line 291 in flyteidl/clients/go/admin/token_source_provider.go

View check run for this annotation

Codecov / codecov/patch

flyteidl/clients/go/admin/token_source_provider.go#L290-L291

Added lines #L290 - L291 were not covered by tests
return &pkce.SimpleTokenSource{
CachedToken: authToken,
}, nil
}

type DeviceFlowTokenSourceProvider struct {
tokenOrchestrator deviceflow.TokenOrchestrator
}
Expand Down
25 changes: 13 additions & 12 deletions flyteidl/clients/go/admin/token_source_provider_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ import (

tokenCacheMocks "github.com/flyteorg/flyte/flyteidl/clients/go/admin/cache/mocks"
adminMocks "github.com/flyteorg/flyte/flyteidl/clients/go/admin/mocks"
"github.com/flyteorg/flyte/flyteidl/clients/go/admin/utils"
"github.com/flyteorg/flyte/flyteidl/gen/pb-go/flyteidl/service"
)

Expand Down Expand Up @@ -88,9 +89,9 @@ func TestCustomTokenSource_Token(t *testing.T) {
minuteAgo := time.Now().Add(-time.Minute)
hourAhead := time.Now().Add(time.Hour)
twoHourAhead := time.Now().Add(2 * time.Hour)
invalidToken := oauth2.Token{AccessToken: "foo", Expiry: minuteAgo}
validToken := oauth2.Token{AccessToken: "foo", Expiry: hourAhead}
newToken := oauth2.Token{AccessToken: "foo", Expiry: twoHourAhead}
invalidToken := utils.GenTokenWithCustomExpiry(t, minuteAgo)
validToken := utils.GenTokenWithCustomExpiry(t, hourAhead)
newToken := utils.GenTokenWithCustomExpiry(t, twoHourAhead)

tests := []struct {
name string
Expand All @@ -101,24 +102,24 @@ func TestCustomTokenSource_Token(t *testing.T) {
{
name: "no cached token",
token: nil,
newToken: &newToken,
expectedToken: &newToken,
newToken: newToken,
expectedToken: newToken,
},
{
name: "cached token valid",
token: &validToken,
token: validToken,
newToken: nil,
expectedToken: &validToken,
expectedToken: validToken,
},
{
name: "cached token expired",
token: &invalidToken,
newToken: &newToken,
expectedToken: &newToken,
token: invalidToken,
newToken: newToken,
expectedToken: newToken,
},
{
name: "failed new token",
token: &invalidToken,
token: invalidToken,
newToken: nil,
expectedToken: nil,
},
Expand All @@ -138,7 +139,7 @@ func TestCustomTokenSource_Token(t *testing.T) {
assert.True(t, ok)

mockSource := &adminMocks.TokenSource{}
if test.token != &validToken {
if test.token != validToken {
if test.newToken != nil {
mockSource.OnToken().Return(test.newToken, nil)
} else {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import (

"github.com/flyteorg/flyte/flyteidl/clients/go/admin/cache"
"github.com/flyteorg/flyte/flyteidl/clients/go/admin/oauth"
"github.com/flyteorg/flyte/flyteidl/clients/go/admin/utils"
"github.com/flyteorg/flyte/flyteidl/gen/pb-go/flyteidl/service"
"github.com/flyteorg/flyte/flytestdlib/config"
"github.com/flyteorg/flyte/flytestdlib/logger"
Expand Down Expand Up @@ -52,7 +53,8 @@ func (t BaseTokenOrchestrator) FetchTokenFromCacheOrRefreshIt(ctx context.Contex
return nil, err
}

if token.Valid() {
if isValid := utils.Valid(token); isValid {
logger.Infof(context.Background(), "retrieved token from cache with expiry %v", token.Expiry)
return token, nil
}

Expand Down
Loading
Loading