diff --git a/flytekit/clients/auth_helper.py b/flytekit/clients/auth_helper.py index 5b53856863..b1ee91d9ad 100644 --- a/flytekit/clients/auth_helper.py +++ b/flytekit/clients/auth_helper.py @@ -16,7 +16,7 @@ DeviceCodeAuthenticator, PKCEAuthenticator, ) -from flytekit.clients.grpc_utils.auth_interceptor import AuthUnaryInterceptor +from flytekit.clients.grpc_utils.auth_interceptor import AuthUnaryInterceptor, EagerAuthUnaryInterceptor from flytekit.clients.grpc_utils.default_metadata_interceptor import DefaultMetadataInterceptor from flytekit.clients.grpc_utils.wrap_exception_interceptor import RetryExceptionWrapperInterceptor from flytekit.configuration import AuthType, PlatformConfig @@ -124,7 +124,7 @@ def upgrade_channel_to_proxy_authenticated(cfg: PlatformConfig, in_channel: grpc """ if cfg.proxy_command: proxy_authenticator = get_proxy_authenticator(cfg) - return grpc.intercept_channel(in_channel, AuthUnaryInterceptor(proxy_authenticator)) + return grpc.intercept_channel(in_channel, EagerAuthUnaryInterceptor(proxy_authenticator)) else: return in_channel diff --git a/flytekit/clients/grpc_utils/auth_interceptor.py b/flytekit/clients/grpc_utils/auth_interceptor.py index e467801a77..3891a74539 100644 --- a/flytekit/clients/grpc_utils/auth_interceptor.py +++ b/flytekit/clients/grpc_utils/auth_interceptor.py @@ -78,3 +78,29 @@ def intercept_unary_stream(self, continuation, client_call_details, request): updated_call_details = self._call_details_with_auth_metadata(client_call_details) return continuation(updated_call_details, request) return c + + +class EagerAuthUnaryInterceptor(AuthUnaryInterceptor): + """ + This Interceptor can be used to automatically add Auth Metadata for every call - without trying without + authentication first. + """ + + def intercept_unary_unary( + self, + continuation: typing.Callable, + client_call_details: grpc.ClientCallDetails, + request: typing.Any, + ): + """ + Intercepts unary calls and proacively adds auth metadata. + """ + self._authenticator.refresh_credentials() + return super().intercept_unary_unary(continuation, client_call_details, request) + + def intercept_unary_stream(self, continuation, client_call_details, request): + """ + Handles stream calls and proacively adds auth metadata. + """ + self._authenticator.refresh_credentials() + return super().intercept_unary_stream(continuation, client_call_details, request) diff --git a/tests/flytekit/unit/clients/test_auth_helper.py b/tests/flytekit/unit/clients/test_auth_helper.py index 4baac2ebc5..4a45b3d8ba 100644 --- a/tests/flytekit/unit/clients/test_auth_helper.py +++ b/tests/flytekit/unit/clients/test_auth_helper.py @@ -22,7 +22,7 @@ upgrade_channel_to_proxy_authenticated, wrap_exceptions_channel, ) -from flytekit.clients.grpc_utils.auth_interceptor import AuthUnaryInterceptor +from flytekit.clients.grpc_utils.auth_interceptor import AuthUnaryInterceptor, EagerAuthUnaryInterceptor from flytekit.clients.grpc_utils.wrap_exception_interceptor import RetryExceptionWrapperInterceptor from flytekit.configuration import AuthType, PlatformConfig @@ -171,7 +171,7 @@ def test_upgrade_channel_to_proxy_auth(): ), ch, ) - assert isinstance(out_ch._interceptor, AuthUnaryInterceptor) + assert isinstance(out_ch._interceptor, EagerAuthUnaryInterceptor) assert isinstance(out_ch._interceptor._authenticator, CommandAuthenticator)