diff --git a/src/api/clients.rs b/src/api/clients.rs index 96bc4a2..a569df7 100644 --- a/src/api/clients.rs +++ b/src/api/clients.rs @@ -9,7 +9,6 @@ use custom_error::custom_error; use tonic::codegen::InterceptedService; use tonic::service::Interceptor; use tonic::transport::{Channel, ClientTlsConfig, Endpoint}; -use tonic::{Request, Status}; #[cfg(feature = "interceptors")] use crate::api::interceptors::{AccessTokenInterceptor, ServiceAccountInterceptor}; @@ -45,66 +44,30 @@ custom_error! { TlsInitializationError = "could not setup tls connection", } -#[cfg(feature = "interceptors")] -enum AuthType { - None, - AccessToken(String), - ServiceAccount(ServiceAccount, Option), -} - -/// Helper [Interceptor] that allows chaining of multiple interceptors. -/// This is used to help return the same type in all builder methods like -/// [ClientBuilder::build_management_client]. Otherwise, each interceptor -/// would create its own return type. With this interceptor, the return type -/// stays the same and is not dependent on the authentication type used. -/// The builder can always return `Client>`. -pub struct ChainedInterceptor { - interceptors: Vec>, -} - -impl ChainedInterceptor { - pub(crate) fn new() -> Self { - Self { - interceptors: Vec::new(), - } - } - - #[cfg(feature = "interceptors")] - pub(crate) fn add_interceptor(mut self, interceptor: Box) -> Self { - self.interceptors.push(interceptor); - self - } -} - -impl Interceptor for ChainedInterceptor { - fn call(&mut self, request: Request<()>) -> Result, Status> { - let mut request = request; - for interceptor in &mut self.interceptors { - request = interceptor.call(request)?; - } - Ok(request) - } -} - /// A builder to create configured gRPC clients for ZITADEL API access. /// The builder accepts the api endpoint and (depending on activated features) /// an authentication method. pub struct ClientBuilder { api_endpoint: String, - #[cfg(feature = "interceptors")] - auth_type: AuthType, } +#[cfg(feature = "interceptors")] +pub struct ClientBuilderWithInterceptor { + api_endpoint: String, + interceptor: T, +} + + impl ClientBuilder { /// Create a new client builder with the the provided endpoint. - pub fn new(api_endpoint: &str) -> Self { - Self { + pub fn new(api_endpoint: &str) -> ClientBuilder { + ClientBuilder { api_endpoint: api_endpoint.to_string(), - #[cfg(feature = "interceptors")] - auth_type: AuthType::None, } } +} +impl ClientBuilder { /// Configure the client builder to use a provided access token. /// This can be a pre-fetched token from ZITADEL or some other form /// of a valid access token like a personal access token (PAT). @@ -112,9 +75,11 @@ impl ClientBuilder { /// Clients with this authentication method will have the [`AccessTokenInterceptor`] /// attached. #[cfg(feature = "interceptors")] - pub fn with_access_token(mut self, access_token: &str) -> Self { - self.auth_type = AuthType::AccessToken(access_token.to_string()); - self + pub fn with_access_token(self, access_token: &str) -> ClientBuilderWithInterceptor { + ClientBuilderWithInterceptor { + api_endpoint: self.api_endpoint, + interceptor: AccessTokenInterceptor::new(access_token), + } } /// Configure the client builder to use a [`ServiceAccount`][crate::credentials::ServiceAccount]. @@ -125,14 +90,151 @@ impl ClientBuilder { /// that fetches an access token from ZITADEL and renewes it when it expires. #[cfg(feature = "interceptors")] pub fn with_service_account( - mut self, + self, service_account: &ServiceAccount, auth_options: Option, - ) -> Self { - self.auth_type = AuthType::ServiceAccount(service_account.clone(), auth_options); - self + ) -> ClientBuilderWithInterceptor { + let interceptor = ServiceAccountInterceptor::new( + &self.api_endpoint, + service_account, + auth_options.clone(), + ); + ClientBuilderWithInterceptor { + api_endpoint: self.api_endpoint, + interceptor, + } } +} +impl ClientBuilder { + /// Create a new [`AdminServiceClient`]. + /// + /// ### Errors + /// + /// This function returns a [`ClientError`] if the provided API endpoint + /// cannot be parsed into a valid URL or if the connection to the endpoint + /// is not possible. + #[cfg(feature = "api-admin-v1")] + pub async fn build_admin_client(self) -> Result, Box> { + let channel = crate::api::clients::get_channel(&self.api_endpoint).await?; + Ok(AdminServiceClient::new(channel)) + } + + /// Create a new [`AuthServiceClient`]. + /// + /// ### Errors + /// + /// This function returns a [`ClientError`] if the provided API endpoint + /// cannot be parsed into a valid URL or if the connection to the endpoint + /// is not possible. + #[cfg(feature = "api-auth-v1")] + pub async fn build_auth_client(self) -> Result, Box> { + let channel = crate::api::clients::get_channel(&self.api_endpoint).await?; + Ok(AuthServiceClient::new(channel)) + } + + /// Create a new [`ManagementServiceClient`]. + /// + /// ### Errors + /// + /// This function returns a [`ClientError`] if the provided API endpoint + /// cannot be parsed into a valid URL or if the connection to the endpoint + /// is not possible. + #[cfg(feature = "api-management-v1")] + pub async fn build_management_client( + self, + ) -> Result, Box> { + let channel = crate::api::clients::get_channel(&self.api_endpoint).await?; + Ok(ManagementServiceClient::new(channel)) + } + + /// Create a new [`OidcServiceClient`]. + /// + /// ### Errors + /// + /// This function returns a [`ClientError`] if the provided API endpoint + /// cannot be parsed into a valid URL or if the connection to the endpoint + /// is not possible. + #[cfg(feature = "api-oidc-v2")] + pub async fn build_oidc_client(self) -> Result, Box> { + let channel = crate::api::clients::get_channel(&self.api_endpoint).await?; + Ok(OidcServiceClient::new(channel)) + } + + /// Create a new [`OrganizationServiceClient`]. + /// + /// ### Errors + /// + /// This function returns a [`ClientError`] if the provided API endpoint + /// cannot be parsed into a valid URL or if the connection to the endpoint + /// is not possible. + #[cfg(feature = "api-org-v2")] + pub async fn build_organization_client( + self, + ) -> Result, Box> { + let channel = crate::api::clients::get_channel(&self.api_endpoint).await?; + Ok(OrganizationServiceClient::new(channel)) + } + + /// Create a new [`SessionServiceClient`]. + /// + /// ### Errors + /// + /// This function returns a [`ClientError`] if the provided API endpoint + /// cannot be parsed into a valid URL or if the connection to the endpoint + /// is not possible. + #[cfg(feature = "api-session-v2")] + pub async fn build_session_client( + self, + ) -> Result, Box> { + let channel = crate::api::clients::get_channel(&self.api_endpoint).await?; + Ok(SessionServiceClient::new(channel)) + } + + /// Create a new [`SettingsServiceClient`]. + /// + /// ### Errors + /// + /// This function returns a [`ClientError`] if the provided API endpoint + /// cannot be parsed into a valid URL or if the connection to the endpoint + /// is not possible. + #[cfg(feature = "api-settings-v2")] + pub async fn build_settings_client( + self, + ) -> Result, Box> { + let channel = crate::api::clients::get_channel(&self.api_endpoint).await?; + Ok(SettingsServiceClient::new(channel)) + } + + /// Create a new [`SystemServiceClient`]. + /// + /// ### Errors + /// + /// This function returns a [`ClientError`] if the provided API endpoint + /// cannot be parsed into a valid URL or if the connection to the endpoint + /// is not possible. + #[cfg(feature = "api-system-v1")] + pub async fn build_system_client(self) -> Result, Box> { + let channel = crate::api::clients::get_channel(&self.api_endpoint).await?; + Ok(SystemServiceClient::new(channel)) + } + + /// Create a new [`UserServiceClient`]. + /// + /// ### Errors + /// + /// This function returns a [`ClientError`] if the provided API endpoint + /// cannot be parsed into a valid URL or if the connection to the endpoint + /// is not possible. + #[cfg(feature = "api-user-v2")] + pub async fn build_user_client(self) -> Result, Box> { + let channel = crate::api::clients::get_channel(&self.api_endpoint).await?; + Ok(UserServiceClient::new(channel)) + } +} + +#[cfg(feature = "interceptors")] +impl ClientBuilderWithInterceptor { /// Create a new [`AdminServiceClient`]. /// /// Depending on the configured authentication method, the client has @@ -145,13 +247,12 @@ impl ClientBuilder { /// is not possible. #[cfg(feature = "api-admin-v1")] pub async fn build_admin_client( - &self, - ) -> Result>, Box> - { + self, + ) -> Result>, Box> { let channel = get_channel(&self.api_endpoint).await?; Ok(AdminServiceClient::with_interceptor( channel, - self.get_chained_interceptor(), + self.interceptor, )) } @@ -167,13 +268,12 @@ impl ClientBuilder { /// is not possible. #[cfg(feature = "api-auth-v1")] pub async fn build_auth_client( - &self, - ) -> Result>, Box> - { + self, + ) -> Result>, Box> { let channel = get_channel(&self.api_endpoint).await?; Ok(AuthServiceClient::with_interceptor( channel, - self.get_chained_interceptor(), + self.interceptor, )) } @@ -189,15 +289,12 @@ impl ClientBuilder { /// is not possible. #[cfg(feature = "api-management-v1")] pub async fn build_management_client( - &self, - ) -> Result< - ManagementServiceClient>, - Box, - > { + self, + ) -> Result>, Box> { let channel = get_channel(&self.api_endpoint).await?; Ok(ManagementServiceClient::with_interceptor( channel, - self.get_chained_interceptor(), + self.interceptor, )) } @@ -213,13 +310,12 @@ impl ClientBuilder { /// is not possible. #[cfg(feature = "api-oidc-v2")] pub async fn build_oidc_client( - &self, - ) -> Result>, Box> - { + self, + ) -> Result>, Box> { let channel = get_channel(&self.api_endpoint).await?; Ok(OidcServiceClient::with_interceptor( channel, - self.get_chained_interceptor(), + self.interceptor, )) } @@ -235,15 +331,12 @@ impl ClientBuilder { /// is not possible. #[cfg(feature = "api-org-v2")] pub async fn build_organization_client( - &self, - ) -> Result< - OrganizationServiceClient>, - Box, - > { + self, + ) -> Result>, Box> { let channel = get_channel(&self.api_endpoint).await?; Ok(OrganizationServiceClient::with_interceptor( channel, - self.get_chained_interceptor(), + self.interceptor, )) } @@ -259,13 +352,12 @@ impl ClientBuilder { /// is not possible. #[cfg(feature = "api-session-v2")] pub async fn build_session_client( - &self, - ) -> Result>, Box> - { + self, + ) -> Result>, Box> { let channel = get_channel(&self.api_endpoint).await?; Ok(SessionServiceClient::with_interceptor( channel, - self.get_chained_interceptor(), + self.interceptor, )) } @@ -281,15 +373,12 @@ impl ClientBuilder { /// is not possible. #[cfg(feature = "api-settings-v2")] pub async fn build_settings_client( - &self, - ) -> Result< - SettingsServiceClient>, - Box, - > { + self, + ) -> Result>, Box> { let channel = get_channel(&self.api_endpoint).await?; Ok(SettingsServiceClient::with_interceptor( channel, - self.get_chained_interceptor(), + self.interceptor, )) } @@ -305,13 +394,12 @@ impl ClientBuilder { /// is not possible. #[cfg(feature = "api-system-v1")] pub async fn build_system_client( - &self, - ) -> Result>, Box> - { + self, + ) -> Result>, Box> { let channel = get_channel(&self.api_endpoint).await?; Ok(SystemServiceClient::with_interceptor( channel, - self.get_chained_interceptor(), + self.interceptor, )) } @@ -327,38 +415,14 @@ impl ClientBuilder { /// is not possible. #[cfg(feature = "api-user-v2")] pub async fn build_user_client( - &self, - ) -> Result>, Box> - { + self, + ) -> Result>, Box> { let channel = get_channel(&self.api_endpoint).await?; Ok(UserServiceClient::with_interceptor( channel, - self.get_chained_interceptor(), + self.interceptor, )) } - - fn get_chained_interceptor(&self) -> ChainedInterceptor { - #[allow(unused_mut)] - let mut interceptor = ChainedInterceptor::new(); - #[cfg(feature = "interceptors")] - match &self.auth_type { - AuthType::AccessToken(token) => { - interceptor = - interceptor.add_interceptor(Box::new(AccessTokenInterceptor::new(token))); - } - AuthType::ServiceAccount(service_account, auth_options) => { - interceptor = - interceptor.add_interceptor(Box::new(ServiceAccountInterceptor::new( - &self.api_endpoint, - service_account, - auth_options.clone(), - ))); - } - _ => {} - } - - interceptor - } } async fn get_channel(api_endpoint: &str) -> Result { @@ -378,6 +442,7 @@ async fn get_channel(api_endpoint: &str) -> Result { #[cfg(test)] mod tests { use super::*; + use tonic::Request; const ZITADEL_URL: &str = "https://zitadel-libraries-l8boqa.zitadel.cloud"; const SERVICE_ACCOUNT: &str = r#" @@ -388,23 +453,11 @@ mod tests { "userId": "181828061098934529" }"#; - #[test] - fn client_builder_without_auth_passes_requests() { - let mut interceptor = ClientBuilder::new(ZITADEL_URL).get_chained_interceptor(); - let request = Request::new(()); - - assert!(request.metadata().is_empty()); - - let request = interceptor.call(request).unwrap(); - - assert!(request.metadata().is_empty()); - } - #[test] fn client_builder_with_access_token_attaches_it() { let mut interceptor = ClientBuilder::new(ZITADEL_URL) .with_access_token("token") - .get_chained_interceptor(); + .interceptor; let request = Request::new(()); assert!(request.metadata().is_empty()); @@ -423,7 +476,7 @@ mod tests { let sa = ServiceAccount::load_from_json(SERVICE_ACCOUNT).unwrap(); let mut interceptor = ClientBuilder::new(ZITADEL_URL) .with_service_account(&sa, None) - .get_chained_interceptor(); + .interceptor; let request = Request::new(()); assert!(request.metadata().is_empty());