From af254df2a7e88930bd35b9d2b0980b32cbc47231 Mon Sep 17 00:00:00 2001 From: Hannes Herrmann Date: Sat, 24 Aug 2024 15:18:36 +0200 Subject: [PATCH] feat(cloneable-clients): make service account interceptor cloneable --- src/api/clients.rs | 6 +- src/api/interceptors.rs | 156 +++++++++++++++++++++++++++++++++------- 2 files changed, 133 insertions(+), 29 deletions(-) diff --git a/src/api/clients.rs b/src/api/clients.rs index a569df7..04c7ad3 100644 --- a/src/api/clients.rs +++ b/src/api/clients.rs @@ -57,7 +57,6 @@ pub struct ClientBuilderWithInterceptor { interceptor: T, } - impl ClientBuilder { /// Create a new client builder with the the provided endpoint. pub fn new(api_endpoint: &str) -> ClientBuilder { @@ -75,7 +74,10 @@ impl ClientBuilder { /// Clients with this authentication method will have the [`AccessTokenInterceptor`] /// attached. #[cfg(feature = "interceptors")] - pub fn with_access_token(self, access_token: &str) -> ClientBuilderWithInterceptor { + pub fn with_access_token( + self, + access_token: &str, + ) -> ClientBuilderWithInterceptor { ClientBuilderWithInterceptor { api_endpoint: self.api_endpoint, interceptor: AccessTokenInterceptor::new(access_token), diff --git a/src/api/interceptors.rs b/src/api/interceptors.rs index 813b5e1..52f93d9 100644 --- a/src/api/interceptors.rs +++ b/src/api/interceptors.rs @@ -3,6 +3,8 @@ //! interceptors is to authenticate the clients to ZITADEL with //! provided credentials. +use std::ops::Deref; +use std::sync::{Arc, RwLock}; use std::thread; use tokio::runtime::Builder; @@ -41,6 +43,7 @@ use crate::credentials::{AuthenticationOptions, ServiceAccount}; /// # Ok(()) /// # } /// ``` +#[derive(Clone)] pub struct AccessTokenInterceptor { access_token: String, } @@ -125,12 +128,21 @@ impl Interceptor for AccessTokenInterceptor { /// # Ok(()) /// # } /// ``` +#[derive(Clone)] pub struct ServiceAccountInterceptor { + inner: Arc, +} + +struct ServiceAccountInterceptorInner { audience: String, service_account: ServiceAccount, auth_options: AuthenticationOptions, - token: Option, - token_expiry: Option, + state: RwLock>, +} + +struct ServiceAccountInterceptorState { + token: String, + token_expiry: time::OffsetDateTime, } impl ServiceAccountInterceptor { @@ -144,11 +156,12 @@ impl ServiceAccountInterceptor { auth_options: Option, ) -> Self { Self { - audience: audience.to_string(), - service_account: service_account.clone(), - auth_options: auth_options.unwrap_or_default(), - token: None, - token_expiry: None, + inner: Arc::new(ServiceAccountInterceptorInner { + audience: audience.to_string(), + service_account: service_account.clone(), + auth_options: auth_options.unwrap_or_default(), + state: RwLock::new(None), + }), } } } @@ -157,25 +170,32 @@ impl Interceptor for ServiceAccountInterceptor { fn call(&mut self, mut request: tonic::Request<()>) -> Result, Status> { let meta = request.metadata_mut(); if !meta.contains_key("authorization") { - if let Some(token) = &self.token { - if let Some(expiry) = self.token_expiry { - if expiry > time::OffsetDateTime::now_utc() { - meta.insert( - "authorization", - format!("Bearer {}", token).parse().unwrap(), - ); - - return Ok(request); - } + // We unwrap the RWLock to propagate the error if any + // thread panics and the lock is poisoned + let state_read_guard = self.inner.state.read().unwrap(); + + if let Some(ServiceAccountInterceptorState { + token, + token_expiry, + }) = state_read_guard.deref() + { + if token_expiry > &time::OffsetDateTime::now_utc() { + meta.insert( + "authorization", + format!("Bearer {}", token).parse().unwrap(), + ); + + return Ok(request); } } + drop(state_read_guard); - let aud = self.audience.clone(); - let auth = self.auth_options.clone(); - let sa = self.service_account.clone(); + let aud = self.inner.audience.clone(); + let auth = self.inner.auth_options.clone(); + let sa = self.inner.service_account.clone(); let token = thread::spawn(move || { - let rt = Builder::new_multi_thread().enable_all().build().unwrap(); + let rt = Builder::new_current_thread().enable_all().build().unwrap(); rt.block_on(async { sa.authenticate_with_options(&aud, &auth) .await @@ -187,8 +207,14 @@ impl Interceptor for ServiceAccountInterceptor { .join() .map_err(|_| Status::internal("could not fetch token"))??; - self.token = Some(token.clone()); - self.token_expiry = Some(time::OffsetDateTime::now_utc() + time::Duration::minutes(59)); + // We unwrap the RWLock to propagate the error if any + // thread panics and the lock is poisoned + let mut state_write_guard = self.inner.state.write().unwrap(); + + *state_write_guard = Some(ServiceAccountInterceptorState { + token: token.clone(), + token_expiry: time::OffsetDateTime::now_utc() + time::Duration::minutes(59), + }); meta.insert( "authorization", @@ -288,6 +314,46 @@ mod tests { .is_empty()); } + #[test] + fn service_account_interceptor_can_be_cloned_and_shares_token_sync_context() { + let sa = ServiceAccount::load_from_json(SERVICE_ACCOUNT).unwrap(); + let mut interceptor = ServiceAccountInterceptor::new(ZITADEL_URL, &sa, None); + let mut second_interceptor = interceptor.clone(); + let request = Request::new(()); + let second_request = Request::new(()); + + assert!(request.metadata().is_empty()); + assert!(second_request.metadata().is_empty()); + + let request = interceptor.call(request).unwrap(); + let second_request = second_interceptor.call(second_request).unwrap(); + + assert_eq!( + request.metadata().get("authorization"), + second_request.metadata().get("authorization") + ); + } + + #[tokio::test] + async fn service_account_interceptor_can_be_cloned_and_shares_token_async_context() { + let sa = ServiceAccount::load_from_json(SERVICE_ACCOUNT).unwrap(); + let mut interceptor = ServiceAccountInterceptor::new(ZITADEL_URL, &sa, None); + let mut second_interceptor = interceptor.clone(); + let request = Request::new(()); + let second_request = Request::new(()); + + assert!(request.metadata().is_empty()); + assert!(second_request.metadata().is_empty()); + + let request = interceptor.call(request).unwrap(); + let second_request = second_interceptor.call(second_request).unwrap(); + + assert_eq!( + request.metadata().get("authorization"), + second_request.metadata().get("authorization") + ); + } + #[test] fn service_account_interceptor_ignore_existing_auth_metadata_sync_context() { let sa = ServiceAccount::load_from_json(SERVICE_ACCOUNT).unwrap(); @@ -333,10 +399,28 @@ mod tests { let sa = ServiceAccount::load_from_json(SERVICE_ACCOUNT).unwrap(); let mut interceptor = ServiceAccountInterceptor::new(ZITADEL_URL, &sa, None); interceptor.call(Request::new(())).unwrap(); - let token = interceptor.token.clone().unwrap(); + let token = interceptor + .inner + .state + .read() + .unwrap() + .as_ref() + .unwrap() + .token + .clone(); interceptor.call(Request::new(())).unwrap(); - assert_eq!(token, interceptor.token.unwrap()); + assert_eq!( + token, + interceptor + .inner + .state + .read() + .unwrap() + .as_ref() + .unwrap() + .token + ); } #[tokio::test] @@ -344,9 +428,27 @@ mod tests { let sa = ServiceAccount::load_from_json(SERVICE_ACCOUNT).unwrap(); let mut interceptor = ServiceAccountInterceptor::new(ZITADEL_URL, &sa, None); interceptor.call(Request::new(())).unwrap(); - let token = interceptor.token.clone().unwrap(); + let token = interceptor + .inner + .state + .read() + .unwrap() + .as_ref() + .unwrap() + .token + .clone(); interceptor.call(Request::new(())).unwrap(); - assert_eq!(token, interceptor.token.unwrap()); + assert_eq!( + token, + interceptor + .inner + .state + .read() + .unwrap() + .as_ref() + .unwrap() + .token + ); } }