Skip to content

Commit

Permalink
Upstream changes to fix token validity and utilizing inmemory creds s…
Browse files Browse the repository at this point in the history
…ource (#6001)

* Auth/prevent lookup per call (#5686) (#555)

Cherry-pick the following change to populate oauth metadata once on initialization using Sync.Do
ca04314

Tested locally using uctl-admin and fetched projects calling into admin which exercises the auth flow
https://buildkite.com/unionai/org-staging-sync/builds/3541

Rollout to all canary and then prod tenants

- [x] To be upstreamed to OSS

*TODO: Link Linear issue(s) using [magic words](https://linear.app/docs/github#magic-words). `fixes` will move to merged status, while `ref` will only link the PR.*

* [ ] Added tests
* [ ] Ran a deploy dry run and shared the terraform plan
* [ ] Added logging and metrics
* [ ] Updated [dashboards](https://unionai.grafana.net/dashboards) and [alerts](https://unionai.grafana.net/alerting/list)
* [ ] Updated documentation

Signed-off-by: pmahindrakar-oss <[email protected]>

* [COR-1114]  Fix token validity check logic to use exp field in access token (#330)

* Add logs for token

* add logs

* Fixing the validity check logic for token

* nit

* nit

* Adding in memory token source provider

* nit

* changed Valid method to log and ignore parseDateClaim error

* nit

* Fix unit tests

* lint

* fix unit tests

Signed-off-by: pmahindrakar-oss <[email protected]>

* remove debug logs

Signed-off-by: pmahindrakar-oss <[email protected]>

---------

Signed-off-by: pmahindrakar-oss <[email protected]>
  • Loading branch information
pmahindrakar-oss authored Nov 14, 2024
1 parent f27aff4 commit da220be
Show file tree
Hide file tree
Showing 10 changed files with 177 additions and 91 deletions.
33 changes: 23 additions & 10 deletions flyteidl/clients/go/admin/auth_interceptor.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ import (
"google.golang.org/grpc/status"

"github.com/flyteorg/flyte/flyteidl/clients/go/admin/cache"
"github.com/flyteorg/flyte/flyteidl/clients/go/admin/utils"
"github.com/flyteorg/flyte/flyteidl/gen/pb-go/flyteidl/service"
"github.com/flyteorg/flyte/flytestdlib/logger"
)
Expand All @@ -23,7 +24,6 @@ const ProxyAuthorizationHeader = "proxy-authorization"
// Once established, it'll invoke PerRPCCredentialsFuture.Store() on perRPCCredentials to populate it with the appropriate values.
func MaterializeCredentials(tokenSource oauth2.TokenSource, cfg *Config, authorizationMetadataKey string,
perRPCCredentials *PerRPCCredentialsFuture) error {

_, err := tokenSource.Token()
if err != nil {
return fmt.Errorf("failed to issue token. Error: %w", err)
Expand All @@ -35,6 +35,19 @@ func MaterializeCredentials(tokenSource oauth2.TokenSource, cfg *Config, authori
return nil
}

// MaterializeInMemoryCredentials initializes the perRPCCredentials with the token source containing in memory cached token.
// This path doesn't perform the token refresh and only build the cred source with cached token.
func MaterializeInMemoryCredentials(ctx context.Context, cfg *Config, tokenCache cache.TokenCache,
perRPCCredentials *PerRPCCredentialsFuture, authorizationMetadataKey string) error {
tokenSource, err := NewInMemoryTokenSourceProvider(tokenCache).GetTokenSource(ctx)
if err != nil {
return fmt.Errorf("failed to get token source. Error: %w", err)
}
wrappedTokenSource := NewCustomHeaderTokenSource(tokenSource, cfg.UseInsecureConnection, authorizationMetadataKey)
perRPCCredentials.Store(wrappedTokenSource)
return nil
}

func GetProxyTokenSource(ctx context.Context, cfg *Config) (oauth2.TokenSource, error) {
tokenSourceProvider, err := NewExternalTokenSourceProvider(cfg.ProxyCommand)
if err != nil {
Expand Down Expand Up @@ -152,6 +165,7 @@ func (o *OauthMetadataProvider) GetOauthMetadata(cfg *Config, tokenCache cache.T
if err != nil {
logger.Errorf(context.Background(), "Failed to load token related config. Error: %v", err)
}
logger.Debugf(context.Background(), "Successfully loaded token related metadata")
})
if err != nil {
return err
Expand All @@ -176,22 +190,21 @@ func NewAuthInterceptor(cfg *Config, tokenCache cache.TokenCache, credentialsFut
}

return func(ctx context.Context, method string, req, reply interface{}, cc *grpc.ClientConn, invoker grpc.UnaryInvoker, opts ...grpc.CallOption) error {

ctx = setHTTPClientContext(ctx, cfg, proxyCredentialsFuture)

// If there is already a token in the cache (e.g. key-ring), we should use it immediately...
t, _ := tokenCache.GetToken()
if t != nil {

err := oauthMetadataProvider.GetOauthMetadata(cfg, tokenCache, proxyCredentialsFuture)
if err != nil {
return err
}
authorizationMetadataKey := oauthMetadataProvider.authorizationMetadataKey
tokenSource := oauthMetadataProvider.tokenSource

err = MaterializeCredentials(tokenSource, cfg, authorizationMetadataKey, credentialsFuture)
if err != nil {
return fmt.Errorf("failed to materialize credentials. Error: %v", err)
if isValid := utils.Valid(t); isValid {
err := MaterializeInMemoryCredentials(ctx, cfg, tokenCache, credentialsFuture, authorizationMetadataKey)
if err != nil {
return fmt.Errorf("failed to materialize credentials. Error: %v", err)
}
}
}

Expand All @@ -208,13 +221,11 @@ func NewAuthInterceptor(cfg *Config, tokenCache cache.TokenCache, credentialsFut
}
authorizationMetadataKey := oauthMetadataProvider.authorizationMetadataKey
tokenSource := oauthMetadataProvider.tokenSource

err = func() error {
if !tokenCache.TryLock() {
tokenCache.CondWait()
return nil
}

defer tokenCache.Unlock()
_, err := tokenCache.PurgeIfEquals(t)
if err != nil && !errors.Is(err, cache.ErrNotFound) {
Expand All @@ -237,6 +248,7 @@ func NewAuthInterceptor(cfg *Config, tokenCache cache.TokenCache, credentialsFut
if err != nil {
return err
}

return invoker(ctx, method, req, reply, cc, opts...)
}
}
Expand All @@ -257,6 +269,7 @@ func NewProxyAuthInterceptor(cfg *Config, proxyCredentialsFuture *PerRPCCredenti
}
return invoker(ctx, method, req, reply, cc, opts...)
}

return err
}
}
13 changes: 5 additions & 8 deletions flyteidl/clients/go/admin/auth_interceptor_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,17 +2,16 @@ package admin

import (
"context"
"encoding/json"
"errors"
"fmt"
"io"
"net"
"net/http"
"net/url"
"os"
"strings"
"sync"
"testing"
"time"

"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/mock"
Expand All @@ -24,7 +23,7 @@ import (

"github.com/flyteorg/flyte/flyteidl/clients/go/admin/cache/mocks"
adminMocks "github.com/flyteorg/flyte/flyteidl/clients/go/admin/mocks"

"github.com/flyteorg/flyte/flyteidl/clients/go/admin/utils"
"github.com/flyteorg/flyte/flyteidl/gen/pb-go/flyteidl/service"
"github.com/flyteorg/flyte/flytestdlib/config"
"github.com/flyteorg/flyte/flytestdlib/logger"
Expand Down Expand Up @@ -137,10 +136,7 @@ func newAuthMetadataServer(t testing.TB, grpcPort int, httpPort int, impl servic
}

func Test_newAuthInterceptor(t *testing.T) {
plan, _ := os.ReadFile("tokenorchestrator/testdata/token.json")
var tokenData oauth2.Token
err := json.Unmarshal(plan, &tokenData)
assert.NoError(t, err)
tokenData := utils.GenTokenWithCustomExpiry(t, time.Now().Add(20*time.Minute))
t.Run("Other Error", func(t *testing.T) {
ctx := context.Background()
httpPort := rand.IntnRange(10000, 60000)
Expand All @@ -164,7 +160,8 @@ func Test_newAuthInterceptor(t *testing.T) {
f := NewPerRPCCredentialsFuture()
p := NewPerRPCCredentialsFuture()
mockTokenCache := &mocks.TokenCache{}
mockTokenCache.OnGetTokenMatch().Return(&tokenData, nil)

mockTokenCache.OnGetTokenMatch().Return(tokenData, nil)
mockTokenCache.OnSaveTokenMatch(mock.Anything).Return(nil)
interceptor := NewAuthInterceptor(&Config{
Endpoint: config.URL{URL: *u},
Expand Down
23 changes: 7 additions & 16 deletions flyteidl/clients/go/admin/client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,10 @@ package admin

import (
"context"
"encoding/json"
"errors"
"fmt"
"io/ioutil"
"net/http"
"net/url"
"os"
"testing"
"time"

Expand All @@ -24,6 +21,7 @@ import (
"github.com/flyteorg/flyte/flyteidl/clients/go/admin/oauth"
"github.com/flyteorg/flyte/flyteidl/clients/go/admin/pkce"
"github.com/flyteorg/flyte/flyteidl/clients/go/admin/tokenorchestrator"
"github.com/flyteorg/flyte/flyteidl/clients/go/admin/utils"
"github.com/flyteorg/flyte/flyteidl/gen/pb-go/flyteidl/service"
"github.com/flyteorg/flyte/flytestdlib/config"
"github.com/flyteorg/flyte/flytestdlib/logger"
Expand Down Expand Up @@ -231,15 +229,11 @@ func TestGetAuthenticationDialOptionPkce(t *testing.T) {
RedirectUri: "http://localhost:54545/callback",
}
http.DefaultServeMux = http.NewServeMux()
plan, _ := os.ReadFile("tokenorchestrator/testdata/token.json")
var tokenData oauth2.Token
err := json.Unmarshal(plan, &tokenData)
assert.NoError(t, err)
tokenData.Expiry = time.Now().Add(time.Minute)
tokenData := utils.GenTokenWithCustomExpiry(t, time.Now().Add(time.Minute))
t.Run("cache hit", func(t *testing.T) {
mockTokenCache := new(cachemocks.TokenCache)
mockAuthClient := new(mocks.AuthMetadataServiceClient)
mockTokenCache.OnGetTokenMatch().Return(&tokenData, nil)
mockTokenCache.OnGetTokenMatch().Return(tokenData, nil)
mockTokenCache.OnSaveTokenMatch(mock.Anything).Return(nil)
mockAuthClient.OnGetOAuth2MetadataMatch(mock.Anything, mock.Anything).Return(metadata, nil)
mockAuthClient.OnGetPublicClientConfigMatch(mock.Anything, mock.Anything).Return(clientMetatadata, nil)
Expand All @@ -249,11 +243,11 @@ func TestGetAuthenticationDialOptionPkce(t *testing.T) {
assert.NotNil(t, dialOption)
assert.Nil(t, err)
})
tokenData.Expiry = time.Now().Add(-time.Minute)
t.Run("cache miss auth failure", func(t *testing.T) {
tokenData = utils.GenTokenWithCustomExpiry(t, time.Now().Add(-time.Minute))
mockTokenCache := new(cachemocks.TokenCache)
mockAuthClient := new(mocks.AuthMetadataServiceClient)
mockTokenCache.OnGetTokenMatch().Return(&tokenData, nil)
mockTokenCache.OnGetTokenMatch().Return(tokenData, nil)
mockTokenCache.OnSaveTokenMatch(mock.Anything).Return(nil)
mockTokenCache.On("Lock").Return()
mockTokenCache.On("Unlock").Return()
Expand Down Expand Up @@ -284,14 +278,11 @@ func Test_getPkceAuthTokenSource(t *testing.T) {
mockAuthClient.OnGetPublicClientConfigMatch(mock.Anything, mock.Anything).Return(clientMetatadata, nil)

t.Run("cached token expired", func(t *testing.T) {
plan, _ := ioutil.ReadFile("tokenorchestrator/testdata/token.json")
var tokenData oauth2.Token
err := json.Unmarshal(plan, &tokenData)
assert.NoError(t, err)
tokenData := utils.GenTokenWithCustomExpiry(t, time.Now().Add(-time.Minute))

// populate the cache
tokenCache := cache.NewTokenCacheInMemoryProvider()
assert.NoError(t, tokenCache.SaveToken(&tokenData))
assert.NoError(t, tokenCache.SaveToken(tokenData))

baseOrchestrator := tokenorchestrator.BaseTokenOrchestrator{
ClientConfig: &oauth.Config{
Expand Down
46 changes: 39 additions & 7 deletions flyteidl/clients/go/admin/token_source_provider.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ import (
"github.com/flyteorg/flyte/flyteidl/clients/go/admin/externalprocess"
"github.com/flyteorg/flyte/flyteidl/clients/go/admin/pkce"
"github.com/flyteorg/flyte/flyteidl/clients/go/admin/tokenorchestrator"
"github.com/flyteorg/flyte/flyteidl/clients/go/admin/utils"
"github.com/flyteorg/flyte/flyteidl/gen/pb-go/flyteidl/service"
"github.com/flyteorg/flyte/flytestdlib/logger"
)
Expand Down Expand Up @@ -229,28 +230,36 @@ func (s *customTokenSource) Token() (*oauth2.Token, error) {
s.mu.Lock()
defer s.mu.Unlock()

if token, err := s.tokenCache.GetToken(); err == nil && token.Valid() {
return token, nil
token, err := s.tokenCache.GetToken()
if err != nil {
logger.Warnf(s.ctx, "failed to get token from cache: %v", err)
} else {
if isValid := utils.Valid(token); isValid {
logger.Infof(context.Background(), "retrieved token from cache with expiry %v", token.Expiry)
return token, nil
}
}

totalAttempts := s.cfg.MaxRetries + 1 // Add one for initial request attempt
backoff := wait.Backoff{
Duration: s.cfg.PerRetryTimeout.Duration,
Steps: totalAttempts,
}
var token *oauth2.Token
err := retry.OnError(backoff, func(err error) bool {

err = retry.OnError(backoff, func(err error) bool {
return err != nil
}, func() (err error) {
token, err = s.new.Token()
if err != nil {
logger.Infof(s.ctx, "failed to get token: %w", err)
return fmt.Errorf("failed to get token: %w", err)
logger.Infof(s.ctx, "failed to get new token: %w", err)
return fmt.Errorf("failed to get new token: %w", err)
}
logger.Infof(context.Background(), "Fetched new token with expiry %v", token.Expiry)
return nil
})
if err != nil {
return nil, err
logger.Warnf(s.ctx, "failed to get new token: %v", err)
return nil, fmt.Errorf("failed to get new token: %w", err)
}
logger.Infof(s.ctx, "retrieved token with expiry %v", token.Expiry)

Expand All @@ -262,6 +271,29 @@ func (s *customTokenSource) Token() (*oauth2.Token, error) {
return token, nil
}

type InMemoryTokenSourceProvider struct {
tokenCache cache.TokenCache
}

func NewInMemoryTokenSourceProvider(tokenCache cache.TokenCache) TokenSourceProvider {
return InMemoryTokenSourceProvider{tokenCache: tokenCache}
}

func (i InMemoryTokenSourceProvider) GetTokenSource(ctx context.Context) (oauth2.TokenSource, error) {
return GetInMemoryAuthTokenSource(ctx, i.tokenCache)
}

// GetInMemoryAuthTokenSource Returns the token source with cached token
func GetInMemoryAuthTokenSource(ctx context.Context, tokenCache cache.TokenCache) (oauth2.TokenSource, error) {
authToken, err := tokenCache.GetToken()
if err != nil {
return nil, err
}
return &pkce.SimpleTokenSource{
CachedToken: authToken,
}, nil
}

type DeviceFlowTokenSourceProvider struct {
tokenOrchestrator deviceflow.TokenOrchestrator
}
Expand Down
25 changes: 13 additions & 12 deletions flyteidl/clients/go/admin/token_source_provider_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ import (

tokenCacheMocks "github.com/flyteorg/flyte/flyteidl/clients/go/admin/cache/mocks"
adminMocks "github.com/flyteorg/flyte/flyteidl/clients/go/admin/mocks"
"github.com/flyteorg/flyte/flyteidl/clients/go/admin/utils"
"github.com/flyteorg/flyte/flyteidl/gen/pb-go/flyteidl/service"
)

Expand Down Expand Up @@ -88,9 +89,9 @@ func TestCustomTokenSource_Token(t *testing.T) {
minuteAgo := time.Now().Add(-time.Minute)
hourAhead := time.Now().Add(time.Hour)
twoHourAhead := time.Now().Add(2 * time.Hour)
invalidToken := oauth2.Token{AccessToken: "foo", Expiry: minuteAgo}
validToken := oauth2.Token{AccessToken: "foo", Expiry: hourAhead}
newToken := oauth2.Token{AccessToken: "foo", Expiry: twoHourAhead}
invalidToken := utils.GenTokenWithCustomExpiry(t, minuteAgo)
validToken := utils.GenTokenWithCustomExpiry(t, hourAhead)
newToken := utils.GenTokenWithCustomExpiry(t, twoHourAhead)

tests := []struct {
name string
Expand All @@ -101,24 +102,24 @@ func TestCustomTokenSource_Token(t *testing.T) {
{
name: "no cached token",
token: nil,
newToken: &newToken,
expectedToken: &newToken,
newToken: newToken,
expectedToken: newToken,
},
{
name: "cached token valid",
token: &validToken,
token: validToken,
newToken: nil,
expectedToken: &validToken,
expectedToken: validToken,
},
{
name: "cached token expired",
token: &invalidToken,
newToken: &newToken,
expectedToken: &newToken,
token: invalidToken,
newToken: newToken,
expectedToken: newToken,
},
{
name: "failed new token",
token: &invalidToken,
token: invalidToken,
newToken: nil,
expectedToken: nil,
},
Expand All @@ -138,7 +139,7 @@ func TestCustomTokenSource_Token(t *testing.T) {
assert.True(t, ok)

mockSource := &adminMocks.TokenSource{}
if test.token != &validToken {
if test.token != validToken {
if test.newToken != nil {
mockSource.OnToken().Return(test.newToken, nil)
} else {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import (

"github.com/flyteorg/flyte/flyteidl/clients/go/admin/cache"
"github.com/flyteorg/flyte/flyteidl/clients/go/admin/oauth"
"github.com/flyteorg/flyte/flyteidl/clients/go/admin/utils"
"github.com/flyteorg/flyte/flyteidl/gen/pb-go/flyteidl/service"
"github.com/flyteorg/flyte/flytestdlib/config"
"github.com/flyteorg/flyte/flytestdlib/logger"
Expand Down Expand Up @@ -52,7 +53,8 @@ func (t BaseTokenOrchestrator) FetchTokenFromCacheOrRefreshIt(ctx context.Contex
return nil, err
}

if token.Valid() {
if isValid := utils.Valid(token); isValid {
logger.Infof(context.Background(), "retrieved token from cache with expiry %v", token.Expiry)
return token, nil
}

Expand Down
Loading

0 comments on commit da220be

Please sign in to comment.