From da220be87d2ad1eddc91c9e04075c8aecd415f74 Mon Sep 17 00:00:00 2001 From: Prafulla Mahindrakar Date: Wed, 13 Nov 2024 19:01:33 -0800 Subject: [PATCH] Upstream changes to fix token validity and utilizing inmemory creds source (#6001) * Auth/prevent lookup per call (#5686) (#555) Cherry-pick the following change to populate oauth metadata once on initialization using Sync.Do https://github.com/flyteorg/flyte/commit/ca04314c494b89216cfda3b3b1eb44eaff2e9da3 Tested locally using uctl-admin and fetched projects calling into admin which exercises the auth flow https://buildkite.com/unionai/org-staging-sync/builds/3541 Rollout to all canary and then prod tenants - [x] To be upstreamed to OSS *TODO: Link Linear issue(s) using [magic words](https://linear.app/docs/github#magic-words). `fixes` will move to merged status, while `ref` will only link the PR.* * [ ] Added tests * [ ] Ran a deploy dry run and shared the terraform plan * [ ] Added logging and metrics * [ ] Updated [dashboards](https://unionai.grafana.net/dashboards) and [alerts](https://unionai.grafana.net/alerting/list) * [ ] Updated documentation Signed-off-by: pmahindrakar-oss * [COR-1114] Fix token validity check logic to use exp field in access token (#330) * Add logs for token * add logs * Fixing the validity check logic for token * nit * nit * Adding in memory token source provider * nit * changed Valid method to log and ignore parseDateClaim error * nit * Fix unit tests * lint * fix unit tests Signed-off-by: pmahindrakar-oss * remove debug logs Signed-off-by: pmahindrakar-oss --------- Signed-off-by: pmahindrakar-oss --- flyteidl/clients/go/admin/auth_interceptor.go | 33 ++++++++---- .../clients/go/admin/auth_interceptor_test.go | 13 ++--- flyteidl/clients/go/admin/client_test.go | 23 +++----- .../clients/go/admin/token_source_provider.go | 46 +++++++++++++--- .../go/admin/token_source_provider_test.go | 25 ++++----- .../base_token_orchestrator.go | 4 +- .../base_token_orchestrator_test.go | 42 ++++----------- .../tokenorchestrator/testdata/token.json | 6 --- flyteidl/clients/go/admin/utils/test_utils.go | 24 +++++++++ .../clients/go/admin/utils/token_utils.go | 52 +++++++++++++++++++ 10 files changed, 177 insertions(+), 91 deletions(-) delete mode 100644 flyteidl/clients/go/admin/tokenorchestrator/testdata/token.json create mode 100644 flyteidl/clients/go/admin/utils/test_utils.go create mode 100644 flyteidl/clients/go/admin/utils/token_utils.go diff --git a/flyteidl/clients/go/admin/auth_interceptor.go b/flyteidl/clients/go/admin/auth_interceptor.go index 5d3d9fd92f..221dd98e9b 100644 --- a/flyteidl/clients/go/admin/auth_interceptor.go +++ b/flyteidl/clients/go/admin/auth_interceptor.go @@ -13,6 +13,7 @@ import ( "google.golang.org/grpc/status" "github.com/flyteorg/flyte/flyteidl/clients/go/admin/cache" + "github.com/flyteorg/flyte/flyteidl/clients/go/admin/utils" "github.com/flyteorg/flyte/flyteidl/gen/pb-go/flyteidl/service" "github.com/flyteorg/flyte/flytestdlib/logger" ) @@ -23,7 +24,6 @@ const ProxyAuthorizationHeader = "proxy-authorization" // Once established, it'll invoke PerRPCCredentialsFuture.Store() on perRPCCredentials to populate it with the appropriate values. func MaterializeCredentials(tokenSource oauth2.TokenSource, cfg *Config, authorizationMetadataKey string, perRPCCredentials *PerRPCCredentialsFuture) error { - _, err := tokenSource.Token() if err != nil { return fmt.Errorf("failed to issue token. Error: %w", err) @@ -35,6 +35,19 @@ func MaterializeCredentials(tokenSource oauth2.TokenSource, cfg *Config, authori return nil } +// MaterializeInMemoryCredentials initializes the perRPCCredentials with the token source containing in memory cached token. +// This path doesn't perform the token refresh and only build the cred source with cached token. +func MaterializeInMemoryCredentials(ctx context.Context, cfg *Config, tokenCache cache.TokenCache, + perRPCCredentials *PerRPCCredentialsFuture, authorizationMetadataKey string) error { + tokenSource, err := NewInMemoryTokenSourceProvider(tokenCache).GetTokenSource(ctx) + if err != nil { + return fmt.Errorf("failed to get token source. Error: %w", err) + } + wrappedTokenSource := NewCustomHeaderTokenSource(tokenSource, cfg.UseInsecureConnection, authorizationMetadataKey) + perRPCCredentials.Store(wrappedTokenSource) + return nil +} + func GetProxyTokenSource(ctx context.Context, cfg *Config) (oauth2.TokenSource, error) { tokenSourceProvider, err := NewExternalTokenSourceProvider(cfg.ProxyCommand) if err != nil { @@ -152,6 +165,7 @@ func (o *OauthMetadataProvider) GetOauthMetadata(cfg *Config, tokenCache cache.T if err != nil { logger.Errorf(context.Background(), "Failed to load token related config. Error: %v", err) } + logger.Debugf(context.Background(), "Successfully loaded token related metadata") }) if err != nil { return err @@ -176,22 +190,21 @@ func NewAuthInterceptor(cfg *Config, tokenCache cache.TokenCache, credentialsFut } return func(ctx context.Context, method string, req, reply interface{}, cc *grpc.ClientConn, invoker grpc.UnaryInvoker, opts ...grpc.CallOption) error { - ctx = setHTTPClientContext(ctx, cfg, proxyCredentialsFuture) - // If there is already a token in the cache (e.g. key-ring), we should use it immediately... t, _ := tokenCache.GetToken() if t != nil { + err := oauthMetadataProvider.GetOauthMetadata(cfg, tokenCache, proxyCredentialsFuture) if err != nil { return err } authorizationMetadataKey := oauthMetadataProvider.authorizationMetadataKey - tokenSource := oauthMetadataProvider.tokenSource - - err = MaterializeCredentials(tokenSource, cfg, authorizationMetadataKey, credentialsFuture) - if err != nil { - return fmt.Errorf("failed to materialize credentials. Error: %v", err) + if isValid := utils.Valid(t); isValid { + err := MaterializeInMemoryCredentials(ctx, cfg, tokenCache, credentialsFuture, authorizationMetadataKey) + if err != nil { + return fmt.Errorf("failed to materialize credentials. Error: %v", err) + } } } @@ -208,13 +221,11 @@ func NewAuthInterceptor(cfg *Config, tokenCache cache.TokenCache, credentialsFut } authorizationMetadataKey := oauthMetadataProvider.authorizationMetadataKey tokenSource := oauthMetadataProvider.tokenSource - err = func() error { if !tokenCache.TryLock() { tokenCache.CondWait() return nil } - defer tokenCache.Unlock() _, err := tokenCache.PurgeIfEquals(t) if err != nil && !errors.Is(err, cache.ErrNotFound) { @@ -237,6 +248,7 @@ func NewAuthInterceptor(cfg *Config, tokenCache cache.TokenCache, credentialsFut if err != nil { return err } + return invoker(ctx, method, req, reply, cc, opts...) } } @@ -257,6 +269,7 @@ func NewProxyAuthInterceptor(cfg *Config, proxyCredentialsFuture *PerRPCCredenti } return invoker(ctx, method, req, reply, cc, opts...) } + return err } } diff --git a/flyteidl/clients/go/admin/auth_interceptor_test.go b/flyteidl/clients/go/admin/auth_interceptor_test.go index b03171c825..0dee7428bc 100644 --- a/flyteidl/clients/go/admin/auth_interceptor_test.go +++ b/flyteidl/clients/go/admin/auth_interceptor_test.go @@ -2,17 +2,16 @@ package admin import ( "context" - "encoding/json" "errors" "fmt" "io" "net" "net/http" "net/url" - "os" "strings" "sync" "testing" + "time" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/mock" @@ -24,7 +23,7 @@ import ( "github.com/flyteorg/flyte/flyteidl/clients/go/admin/cache/mocks" adminMocks "github.com/flyteorg/flyte/flyteidl/clients/go/admin/mocks" - + "github.com/flyteorg/flyte/flyteidl/clients/go/admin/utils" "github.com/flyteorg/flyte/flyteidl/gen/pb-go/flyteidl/service" "github.com/flyteorg/flyte/flytestdlib/config" "github.com/flyteorg/flyte/flytestdlib/logger" @@ -137,10 +136,7 @@ func newAuthMetadataServer(t testing.TB, grpcPort int, httpPort int, impl servic } func Test_newAuthInterceptor(t *testing.T) { - plan, _ := os.ReadFile("tokenorchestrator/testdata/token.json") - var tokenData oauth2.Token - err := json.Unmarshal(plan, &tokenData) - assert.NoError(t, err) + tokenData := utils.GenTokenWithCustomExpiry(t, time.Now().Add(20*time.Minute)) t.Run("Other Error", func(t *testing.T) { ctx := context.Background() httpPort := rand.IntnRange(10000, 60000) @@ -164,7 +160,8 @@ func Test_newAuthInterceptor(t *testing.T) { f := NewPerRPCCredentialsFuture() p := NewPerRPCCredentialsFuture() mockTokenCache := &mocks.TokenCache{} - mockTokenCache.OnGetTokenMatch().Return(&tokenData, nil) + + mockTokenCache.OnGetTokenMatch().Return(tokenData, nil) mockTokenCache.OnSaveTokenMatch(mock.Anything).Return(nil) interceptor := NewAuthInterceptor(&Config{ Endpoint: config.URL{URL: *u}, diff --git a/flyteidl/clients/go/admin/client_test.go b/flyteidl/clients/go/admin/client_test.go index 042a826692..e61f066c26 100644 --- a/flyteidl/clients/go/admin/client_test.go +++ b/flyteidl/clients/go/admin/client_test.go @@ -2,13 +2,10 @@ package admin import ( "context" - "encoding/json" "errors" "fmt" - "io/ioutil" "net/http" "net/url" - "os" "testing" "time" @@ -24,6 +21,7 @@ import ( "github.com/flyteorg/flyte/flyteidl/clients/go/admin/oauth" "github.com/flyteorg/flyte/flyteidl/clients/go/admin/pkce" "github.com/flyteorg/flyte/flyteidl/clients/go/admin/tokenorchestrator" + "github.com/flyteorg/flyte/flyteidl/clients/go/admin/utils" "github.com/flyteorg/flyte/flyteidl/gen/pb-go/flyteidl/service" "github.com/flyteorg/flyte/flytestdlib/config" "github.com/flyteorg/flyte/flytestdlib/logger" @@ -231,15 +229,11 @@ func TestGetAuthenticationDialOptionPkce(t *testing.T) { RedirectUri: "http://localhost:54545/callback", } http.DefaultServeMux = http.NewServeMux() - plan, _ := os.ReadFile("tokenorchestrator/testdata/token.json") - var tokenData oauth2.Token - err := json.Unmarshal(plan, &tokenData) - assert.NoError(t, err) - tokenData.Expiry = time.Now().Add(time.Minute) + tokenData := utils.GenTokenWithCustomExpiry(t, time.Now().Add(time.Minute)) t.Run("cache hit", func(t *testing.T) { mockTokenCache := new(cachemocks.TokenCache) mockAuthClient := new(mocks.AuthMetadataServiceClient) - mockTokenCache.OnGetTokenMatch().Return(&tokenData, nil) + mockTokenCache.OnGetTokenMatch().Return(tokenData, nil) mockTokenCache.OnSaveTokenMatch(mock.Anything).Return(nil) mockAuthClient.OnGetOAuth2MetadataMatch(mock.Anything, mock.Anything).Return(metadata, nil) mockAuthClient.OnGetPublicClientConfigMatch(mock.Anything, mock.Anything).Return(clientMetatadata, nil) @@ -249,11 +243,11 @@ func TestGetAuthenticationDialOptionPkce(t *testing.T) { assert.NotNil(t, dialOption) assert.Nil(t, err) }) - tokenData.Expiry = time.Now().Add(-time.Minute) t.Run("cache miss auth failure", func(t *testing.T) { + tokenData = utils.GenTokenWithCustomExpiry(t, time.Now().Add(-time.Minute)) mockTokenCache := new(cachemocks.TokenCache) mockAuthClient := new(mocks.AuthMetadataServiceClient) - mockTokenCache.OnGetTokenMatch().Return(&tokenData, nil) + mockTokenCache.OnGetTokenMatch().Return(tokenData, nil) mockTokenCache.OnSaveTokenMatch(mock.Anything).Return(nil) mockTokenCache.On("Lock").Return() mockTokenCache.On("Unlock").Return() @@ -284,14 +278,11 @@ func Test_getPkceAuthTokenSource(t *testing.T) { mockAuthClient.OnGetPublicClientConfigMatch(mock.Anything, mock.Anything).Return(clientMetatadata, nil) t.Run("cached token expired", func(t *testing.T) { - plan, _ := ioutil.ReadFile("tokenorchestrator/testdata/token.json") - var tokenData oauth2.Token - err := json.Unmarshal(plan, &tokenData) - assert.NoError(t, err) + tokenData := utils.GenTokenWithCustomExpiry(t, time.Now().Add(-time.Minute)) // populate the cache tokenCache := cache.NewTokenCacheInMemoryProvider() - assert.NoError(t, tokenCache.SaveToken(&tokenData)) + assert.NoError(t, tokenCache.SaveToken(tokenData)) baseOrchestrator := tokenorchestrator.BaseTokenOrchestrator{ ClientConfig: &oauth.Config{ diff --git a/flyteidl/clients/go/admin/token_source_provider.go b/flyteidl/clients/go/admin/token_source_provider.go index 4ecfa59215..2a51832da6 100644 --- a/flyteidl/clients/go/admin/token_source_provider.go +++ b/flyteidl/clients/go/admin/token_source_provider.go @@ -20,6 +20,7 @@ import ( "github.com/flyteorg/flyte/flyteidl/clients/go/admin/externalprocess" "github.com/flyteorg/flyte/flyteidl/clients/go/admin/pkce" "github.com/flyteorg/flyte/flyteidl/clients/go/admin/tokenorchestrator" + "github.com/flyteorg/flyte/flyteidl/clients/go/admin/utils" "github.com/flyteorg/flyte/flyteidl/gen/pb-go/flyteidl/service" "github.com/flyteorg/flyte/flytestdlib/logger" ) @@ -229,8 +230,14 @@ func (s *customTokenSource) Token() (*oauth2.Token, error) { s.mu.Lock() defer s.mu.Unlock() - if token, err := s.tokenCache.GetToken(); err == nil && token.Valid() { - return token, nil + token, err := s.tokenCache.GetToken() + if err != nil { + logger.Warnf(s.ctx, "failed to get token from cache: %v", err) + } else { + if isValid := utils.Valid(token); isValid { + logger.Infof(context.Background(), "retrieved token from cache with expiry %v", token.Expiry) + return token, nil + } } totalAttempts := s.cfg.MaxRetries + 1 // Add one for initial request attempt @@ -238,19 +245,21 @@ func (s *customTokenSource) Token() (*oauth2.Token, error) { Duration: s.cfg.PerRetryTimeout.Duration, Steps: totalAttempts, } - var token *oauth2.Token - err := retry.OnError(backoff, func(err error) bool { + + err = retry.OnError(backoff, func(err error) bool { return err != nil }, func() (err error) { token, err = s.new.Token() if err != nil { - logger.Infof(s.ctx, "failed to get token: %w", err) - return fmt.Errorf("failed to get token: %w", err) + logger.Infof(s.ctx, "failed to get new token: %w", err) + return fmt.Errorf("failed to get new token: %w", err) } + logger.Infof(context.Background(), "Fetched new token with expiry %v", token.Expiry) return nil }) if err != nil { - return nil, err + logger.Warnf(s.ctx, "failed to get new token: %v", err) + return nil, fmt.Errorf("failed to get new token: %w", err) } logger.Infof(s.ctx, "retrieved token with expiry %v", token.Expiry) @@ -262,6 +271,29 @@ func (s *customTokenSource) Token() (*oauth2.Token, error) { return token, nil } +type InMemoryTokenSourceProvider struct { + tokenCache cache.TokenCache +} + +func NewInMemoryTokenSourceProvider(tokenCache cache.TokenCache) TokenSourceProvider { + return InMemoryTokenSourceProvider{tokenCache: tokenCache} +} + +func (i InMemoryTokenSourceProvider) GetTokenSource(ctx context.Context) (oauth2.TokenSource, error) { + return GetInMemoryAuthTokenSource(ctx, i.tokenCache) +} + +// GetInMemoryAuthTokenSource Returns the token source with cached token +func GetInMemoryAuthTokenSource(ctx context.Context, tokenCache cache.TokenCache) (oauth2.TokenSource, error) { + authToken, err := tokenCache.GetToken() + if err != nil { + return nil, err + } + return &pkce.SimpleTokenSource{ + CachedToken: authToken, + }, nil +} + type DeviceFlowTokenSourceProvider struct { tokenOrchestrator deviceflow.TokenOrchestrator } diff --git a/flyteidl/clients/go/admin/token_source_provider_test.go b/flyteidl/clients/go/admin/token_source_provider_test.go index 43d0fdd928..941b697e75 100644 --- a/flyteidl/clients/go/admin/token_source_provider_test.go +++ b/flyteidl/clients/go/admin/token_source_provider_test.go @@ -13,6 +13,7 @@ import ( tokenCacheMocks "github.com/flyteorg/flyte/flyteidl/clients/go/admin/cache/mocks" adminMocks "github.com/flyteorg/flyte/flyteidl/clients/go/admin/mocks" + "github.com/flyteorg/flyte/flyteidl/clients/go/admin/utils" "github.com/flyteorg/flyte/flyteidl/gen/pb-go/flyteidl/service" ) @@ -88,9 +89,9 @@ func TestCustomTokenSource_Token(t *testing.T) { minuteAgo := time.Now().Add(-time.Minute) hourAhead := time.Now().Add(time.Hour) twoHourAhead := time.Now().Add(2 * time.Hour) - invalidToken := oauth2.Token{AccessToken: "foo", Expiry: minuteAgo} - validToken := oauth2.Token{AccessToken: "foo", Expiry: hourAhead} - newToken := oauth2.Token{AccessToken: "foo", Expiry: twoHourAhead} + invalidToken := utils.GenTokenWithCustomExpiry(t, minuteAgo) + validToken := utils.GenTokenWithCustomExpiry(t, hourAhead) + newToken := utils.GenTokenWithCustomExpiry(t, twoHourAhead) tests := []struct { name string @@ -101,24 +102,24 @@ func TestCustomTokenSource_Token(t *testing.T) { { name: "no cached token", token: nil, - newToken: &newToken, - expectedToken: &newToken, + newToken: newToken, + expectedToken: newToken, }, { name: "cached token valid", - token: &validToken, + token: validToken, newToken: nil, - expectedToken: &validToken, + expectedToken: validToken, }, { name: "cached token expired", - token: &invalidToken, - newToken: &newToken, - expectedToken: &newToken, + token: invalidToken, + newToken: newToken, + expectedToken: newToken, }, { name: "failed new token", - token: &invalidToken, + token: invalidToken, newToken: nil, expectedToken: nil, }, @@ -138,7 +139,7 @@ func TestCustomTokenSource_Token(t *testing.T) { assert.True(t, ok) mockSource := &adminMocks.TokenSource{} - if test.token != &validToken { + if test.token != validToken { if test.newToken != nil { mockSource.OnToken().Return(test.newToken, nil) } else { diff --git a/flyteidl/clients/go/admin/tokenorchestrator/base_token_orchestrator.go b/flyteidl/clients/go/admin/tokenorchestrator/base_token_orchestrator.go index 4fd3fa476c..441127ce07 100644 --- a/flyteidl/clients/go/admin/tokenorchestrator/base_token_orchestrator.go +++ b/flyteidl/clients/go/admin/tokenorchestrator/base_token_orchestrator.go @@ -8,6 +8,7 @@ import ( "github.com/flyteorg/flyte/flyteidl/clients/go/admin/cache" "github.com/flyteorg/flyte/flyteidl/clients/go/admin/oauth" + "github.com/flyteorg/flyte/flyteidl/clients/go/admin/utils" "github.com/flyteorg/flyte/flyteidl/gen/pb-go/flyteidl/service" "github.com/flyteorg/flyte/flytestdlib/config" "github.com/flyteorg/flyte/flytestdlib/logger" @@ -52,7 +53,8 @@ func (t BaseTokenOrchestrator) FetchTokenFromCacheOrRefreshIt(ctx context.Contex return nil, err } - if token.Valid() { + if isValid := utils.Valid(token); isValid { + logger.Infof(context.Background(), "retrieved token from cache with expiry %v", token.Expiry) return token, nil } diff --git a/flyteidl/clients/go/admin/tokenorchestrator/base_token_orchestrator_test.go b/flyteidl/clients/go/admin/tokenorchestrator/base_token_orchestrator_test.go index 0a1a9f4985..d7e5ca07b2 100644 --- a/flyteidl/clients/go/admin/tokenorchestrator/base_token_orchestrator_test.go +++ b/flyteidl/clients/go/admin/tokenorchestrator/base_token_orchestrator_test.go @@ -2,8 +2,6 @@ package tokenorchestrator import ( "context" - "encoding/json" - "os" "testing" "time" @@ -15,6 +13,7 @@ import ( cacheMocks "github.com/flyteorg/flyte/flyteidl/clients/go/admin/cache/mocks" "github.com/flyteorg/flyte/flyteidl/clients/go/admin/mocks" "github.com/flyteorg/flyte/flyteidl/clients/go/admin/oauth" + "github.com/flyteorg/flyte/flyteidl/clients/go/admin/utils" "github.com/flyteorg/flyte/flyteidl/gen/pb-go/flyteidl/service" "github.com/flyteorg/flyte/flytestdlib/config" ) @@ -32,12 +31,9 @@ func TestRefreshTheToken(t *testing.T) { TokenCache: tokenCacheProvider, } - plan, _ := os.ReadFile("testdata/token.json") - var tokenData oauth2.Token - err := json.Unmarshal(plan, &tokenData) - assert.Nil(t, err) t.Run("bad url in Config", func(t *testing.T) { - refreshedToken, err := orchestrator.RefreshToken(ctx, &tokenData) + tokenData := utils.GenTokenWithCustomExpiry(t, time.Now().Add(-20*time.Minute)) + refreshedToken, err := orchestrator.RefreshToken(ctx, tokenData) assert.Nil(t, refreshedToken) assert.NotNil(t, err) }) @@ -72,12 +68,8 @@ func TestFetchFromCache(t *testing.T) { tokenCacheProvider := cache.NewTokenCacheInMemoryProvider() orchestrator, err := NewBaseTokenOrchestrator(ctx, tokenCacheProvider, mockAuthClient) assert.NoError(t, err) - fileData, _ := os.ReadFile("testdata/token.json") - var tokenData oauth2.Token - err = json.Unmarshal(fileData, &tokenData) - assert.Nil(t, err) - tokenData.Expiry = time.Now().Add(20 * time.Minute) - err = tokenCacheProvider.SaveToken(&tokenData) + tokenData := utils.GenTokenWithCustomExpiry(t, time.Now().Add(20*time.Minute)) + err = tokenCacheProvider.SaveToken(tokenData) assert.Nil(t, err) cachedToken, err := orchestrator.FetchTokenFromCacheOrRefreshIt(ctx, config.Duration{Duration: 5 * time.Minute}) assert.Nil(t, err) @@ -89,12 +81,8 @@ func TestFetchFromCache(t *testing.T) { tokenCacheProvider := cache.NewTokenCacheInMemoryProvider() orchestrator, err := NewBaseTokenOrchestrator(ctx, tokenCacheProvider, mockAuthClient) assert.NoError(t, err) - fileData, _ := os.ReadFile("testdata/token.json") - var tokenData oauth2.Token - err = json.Unmarshal(fileData, &tokenData) - assert.Nil(t, err) - tokenData.Expiry = time.Now().Add(-20 * time.Minute) - err = tokenCacheProvider.SaveToken(&tokenData) + tokenData := utils.GenTokenWithCustomExpiry(t, time.Now().Add(-20*time.Minute)) + err = tokenCacheProvider.SaveToken(tokenData) assert.Nil(t, err) _, err = orchestrator.FetchTokenFromCacheOrRefreshIt(ctx, config.Duration{Duration: 5 * time.Minute}) assert.NotNil(t, err) @@ -104,12 +92,8 @@ func TestFetchFromCache(t *testing.T) { mockTokenCacheProvider := new(cacheMocks.TokenCache) orchestrator, err := NewBaseTokenOrchestrator(ctx, mockTokenCacheProvider, mockAuthClient) assert.NoError(t, err) - fileData, _ := os.ReadFile("testdata/token.json") - var tokenData oauth2.Token - err = json.Unmarshal(fileData, &tokenData) - assert.Nil(t, err) - tokenData.Expiry = time.Now().Add(20 * time.Minute) - mockTokenCacheProvider.OnGetTokenMatch(mock.Anything).Return(&tokenData, nil) + tokenData := utils.GenTokenWithCustomExpiry(t, time.Now().Add(20*time.Minute)) + mockTokenCacheProvider.OnGetTokenMatch(mock.Anything).Return(tokenData, nil) mockTokenCacheProvider.OnSaveTokenMatch(mock.Anything).Return(nil) assert.Nil(t, err) refreshedToken, err := orchestrator.FetchTokenFromCacheOrRefreshIt(ctx, config.Duration{Duration: 5 * time.Minute}) @@ -122,12 +106,8 @@ func TestFetchFromCache(t *testing.T) { mockTokenCacheProvider := new(cacheMocks.TokenCache) orchestrator, err := NewBaseTokenOrchestrator(ctx, mockTokenCacheProvider, mockAuthClient) assert.NoError(t, err) - fileData, _ := os.ReadFile("testdata/token.json") - var tokenData oauth2.Token - err = json.Unmarshal(fileData, &tokenData) - assert.Nil(t, err) - tokenData.Expiry = time.Now().Add(20 * time.Minute) - mockTokenCacheProvider.OnGetTokenMatch(mock.Anything).Return(&tokenData, nil) + tokenData := utils.GenTokenWithCustomExpiry(t, time.Now().Add(20*time.Minute)) + mockTokenCacheProvider.OnGetTokenMatch(mock.Anything).Return(tokenData, nil) assert.Nil(t, err) refreshedToken, err := orchestrator.FetchTokenFromCacheOrRefreshIt(ctx, config.Duration{Duration: 5 * time.Minute}) assert.Nil(t, err) diff --git a/flyteidl/clients/go/admin/tokenorchestrator/testdata/token.json b/flyteidl/clients/go/admin/tokenorchestrator/testdata/token.json deleted file mode 100644 index 721cecc5f6..0000000000 --- a/flyteidl/clients/go/admin/tokenorchestrator/testdata/token.json +++ /dev/null @@ -1,6 +0,0 @@ -{ - "access_token":"eyJhbGciOiJSUzI1NiIsImtleV9pZCI6IjlLZlNILXphZjRjY1dmTlNPbm91YmZUbnItVW5kMHVuY3ctWF9KNUJVdWciLCJ0eXAiOiJKV1QifQ.eyJhdWQiOlsiaHR0cHM6Ly9kZW1vLm51Y2x5ZGUuaW8iXSwiY2xpZW50X2lkIjoiZmx5dGVjdGwiLCJleHAiOjE2MTk1Mjk5MjcsImZvcm0iOnsiY29kZV9jaGFsbGVuZ2UiOiJ2bWNxazArZnJRS3Vvb2FMUHZwUDJCeUtod2VKR2VaeG1mdGtkMml0T042Tk13SVBQNWwySmNpWDd3NTdlaS9iVW1LTWhPSjJVUERnK0F5RXRaTG94SFJiMDl1cWRKSSIsImNvZGVfY2hhbGxlbmdlX21ldGhvZCI6IlN2WEgyeDh2UDUrSkJxQ0NjT2dCL0hNWjdLSmE3bkdLMDBaUVA0ekd4WGcifSwiaWF0IjoxNjE5NTAyNTM1LCJpc3MiOiJodHRwczovL2RlbW8ubnVjbHlkZS5pbyIsImp0aSI6IjQzMTM1ZWY2LTA5NjEtNGFlZC1hOTYxLWQyZGI1YWJmM2U1YyIsInNjcCI6WyJvZmZsaW5lIiwiYWxsIiwiYWNjZXNzX3Rva2VuIl0sInN1YiI6IjExNDUyNzgxNTMwNTEyODk3NDQ3MCIsInVzZXJfaW5mbyI6eyJmYW1pbHlfbmFtZSI6Ik1haGluZHJha2FyIiwiZ2l2ZW5fbmFtZSI6IlByYWZ1bGxhIiwibmFtZSI6IlByYWZ1bGxhIE1haGluZHJha2FyIiwicGljdHVyZSI6Imh0dHBzOi8vbGgzLmdvb2dsZXVzZXJjb250ZW50LmNvbS9hLS9BT2gxNEdqdVQxazgtOGE1dkJHT0lGMWFEZ2hZbUZ4OGhEOUtOaVI1am5adT1zOTYtYyIsInN1YmplY3QiOiIxMTQ1Mjc4MTUzMDUxMjg5NzQ0NzAifX0.ojbUOy2tF6HL8fIp1FJAQchU2MimlVMr3EGVPxMvYyahpW5YsWh6mz7qn4vpEnBuYZDf6cTaN50pJ8krlDX9RqtxF3iEfV2ZYHwyKMThI9sWh_kEBgGwUpyHyk98ZeqQX1uFOH3iwwhR-lPPUlpgdFGzKsxfxeFLOtu1y0V7BgA08KFqgYzl0lJqDYWBkJh_wUAv5g_r0NzSQCsMqb-B3Lno5ScMnlA3SZ_Hg-XdW8hnFIlrwJj4Cv47j3fcZxpqLbTNDXWWogmRbJb3YPlgn_LEnRAyZnFERHKMCE9vaBSTu-1Qstp-gRTORjyV7l3y680dEygQS-99KV3OSBlz6g", - "token_type":"bearer", - "refresh_token":"eyJhbGciOiJSUzI1NiIsImtleV9pZCI6IjlLZlNILXphZjRjY1dmTlNPbm91YmZUbnItVW5kMHVuY3ctWF9KNUJVdWciLCJ0eXAiOiJKV1QifQ.eyJhdWQiOlsiaHR0cHM6Ly9kZW1vLm51Y2x5ZGUuaW8iXSwiY2xpZW50X2lkIjoiZmx5dGVjdGwiLCJleHAiOjE2MTk1MzM1MjcsImZvcm0iOnsiY29kZV9jaGFsbGVuZ2UiOiJ2bWNxazArZnJRS3Vvb2FMUHZwUDJCeUtod2VKR2VaeG1mdGtkMml0T042Tk13SVBQNWwySmNpWDd3NTdlaS9iVW1LTWhPSjJVUERnK0F5RXRaTG94SFJiMDl1cWRKSSIsImNvZGVfY2hhbGxlbmdlX21ldGhvZCI6IlN2WEgyeDh2UDUrSkJxQ0NjT2dCL0hNWjdLSmE3bkdLMDBaUVA0ekd4WGcifSwiaWF0IjoxNjE5NTAyNTM1LCJpc3MiOiJodHRwczovL2RlbW8ubnVjbHlkZS5pbyIsImp0aSI6IjQzMTM1ZWY2LTA5NjEtNGFlZC1hOTYxLWQyZGI1YWJmM2U1YyIsInNjcCI6WyJvZmZsaW5lIiwiZi5hbGwiLCJhY2Nlc3NfdG9rZW4iXSwic3ViIjoiMTE0NTI3ODE1MzA1MTI4OTc0NDcwIiwidXNlcl9pbmZvIjp7ImZhbWlseV9uYW1lIjoiTWFoaW5kcmFrYXIiLCJnaXZlbl9uYW1lIjoiUHJhZnVsbGEiLCJuYW1lIjoiUHJhZnVsbGEgTWFoaW5kcmFrYXIiLCJwaWN0dXJlIjoiaHR0cHM6Ly9saDMuZ29vZ2xldXNlcmNvbnRlbnQuY29tL2EtL0FPaDE0R2p1VDFrOC04YTV2QkdPSUYxYURnaFltRng4aEQ5S05pUjVqblp1PXM5Ni1jIiwic3ViamVjdCI6IjExNDUyNzgxNTMwNTEyODk3NDQ3MCJ9fQ.YKom5-gE4e84rJJIfxcpbMzgjZT33UZ27UTa1y8pK2BAWaPjIZtwudwDHQ5Rd3m0mJJWhBp0j0e8h9DvzBUdpsnGMXSCYKP-ag9y9k5OW59FMm9RqIakWHtj6NPnxGO1jAsaNCYePj8knR7pBLCLCse2taDHUJ8RU1F0DeHNr2y-JupgG5y1vjBcb-9eD8OwOSTp686_hm7XoJlxiKx8dj2O7HPH7M2pAHA_0bVrKKj7Y_s3fRhkm_Aq6LRdA-IiTl9xJQxgVUreejls9-RR9mSTKj6A81-Isz3qAUttVVaA4OT5OdW879_yT7OSLw_QwpXzNZ7qOR7OIpmL_xZXig", - "expiry":"2021-04-27T19:55:26.658635+05:30" -} \ No newline at end of file diff --git a/flyteidl/clients/go/admin/utils/test_utils.go b/flyteidl/clients/go/admin/utils/test_utils.go new file mode 100644 index 0000000000..000bbbebba --- /dev/null +++ b/flyteidl/clients/go/admin/utils/test_utils.go @@ -0,0 +1,24 @@ +package utils + +import ( + "testing" + "time" + + "github.com/golang-jwt/jwt/v5" + "github.com/stretchr/testify/assert" + "golang.org/x/oauth2" +) + +func GenTokenWithCustomExpiry(t *testing.T, expiry time.Time) *oauth2.Token { + var signingKey = []byte("your_secret_key") + token := jwt.New(jwt.SigningMethodHS256) + claims := token.Claims.(jwt.MapClaims) + claims["exp"] = expiry.Unix() + tokenString, err := token.SignedString(signingKey) + assert.NoError(t, err) + return &oauth2.Token{ + AccessToken: tokenString, + Expiry: expiry, + TokenType: "bearer", + } +} diff --git a/flyteidl/clients/go/admin/utils/token_utils.go b/flyteidl/clients/go/admin/utils/token_utils.go new file mode 100644 index 0000000000..8c34cef00e --- /dev/null +++ b/flyteidl/clients/go/admin/utils/token_utils.go @@ -0,0 +1,52 @@ +package utils + +import ( + "context" + "fmt" + "time" + + "github.com/golang-jwt/jwt/v5" + "golang.org/x/oauth2" + + "github.com/flyteorg/flyte/flytestdlib/logger" +) + +// Ref : Taken from oAuth library implementation of expiry +// defaultExpiryDelta determines how earlier a token should be considered +// expired than its actual expiration time. It is used to avoid late +// expirations due to client-server time mismatches. +const defaultExpiryDelta = 10 * time.Second + +// Valid reports whether t is non-nil, has an AccessToken, and is not expired. +func Valid(t *oauth2.Token) bool { + if t == nil || t.AccessToken == "" { + return false + } + expiryDelta := defaultExpiryDelta + tokenExpiry, err := parseDateClaim(t.AccessToken) + if err != nil { + logger.Errorf(context.Background(), "parseDateClaim failed due to %v", err) + return false + } + logger.Debugf(context.Background(), "Token expiry : %v, Access token expiry : %v, Are the equal : %v", t.Expiry, tokenExpiry, tokenExpiry.Equal(t.Expiry)) + return !tokenExpiry.Add(-expiryDelta).Before(time.Now()) +} + +// parseDateClaim parses the JWT token string and extracts the expiration time +func parseDateClaim(tokenString string) (time.Time, error) { + // Parse the token + token, _, err := new(jwt.Parser).ParseUnverified(tokenString, jwt.MapClaims{}) + if err != nil { + return time.Time{}, err + } + + // Extract the claims + if claims, ok := token.Claims.(jwt.MapClaims); ok { + // Get the expiration time + if exp, ok := claims["exp"].(float64); ok { + return time.Unix(int64(exp), 0), nil + } + } + + return time.Time{}, fmt.Errorf("no expiration claim found in token") +}